Skip to content

Commit

Permalink
feat: teach VortexExpr to dtype (#1811)
Browse files Browse the repository at this point in the history
Not entirely clear what to do about non-nullable extension types. The
value should never be observed, but it also seems bad to create a Scalar
whose value might violate the assumptions of the extension type.
  • Loading branch information
danking authored Jan 10, 2025
1 parent e5deba0 commit f97c0cd
Show file tree
Hide file tree
Showing 14 changed files with 340 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions vortex-array/src/data/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,15 @@ impl<T: AsRef<ArrayData>> ArrayDType for T {
}
}

impl ArrayData {
pub fn into_dtype(self) -> DType {
match self.0 {
InnerArrayData::Owned(d) => d.dtype,
InnerArrayData::Viewed(v) => v.dtype,
}
}
}

impl<T: AsRef<ArrayData>> ArrayLen for T {
fn len(&self) -> usize {
self.as_ref().len()
Expand Down
5 changes: 5 additions & 0 deletions vortex-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ vortex-error = { workspace = true }
vortex-proto = { workspace = true, optional = true }
vortex-scalar = { workspace = true }

[dev-dependencies]
vortex-expr = { path = ".", features = ["test-harness"] }


[features]
datafusion = [
"dep:datafusion-expr",
Expand All @@ -48,3 +52,4 @@ proto = [
"vortex-proto/expr",
]
serde = ["dep:serde", "vortex-dtype/serde", "vortex-scalar/serde"]
test-harness = []
74 changes: 73 additions & 1 deletion vortex-expr/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl VortexExpr for BinaryExpr {
self
}

fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
let lhs = self.lhs.evaluate(batch)?;
let rhs = self.rhs.evaluate(batch)?;

Expand Down Expand Up @@ -257,3 +257,75 @@ pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
BinaryExpr::new_expr(lhs, Operator::And, rhs)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use vortex_dtype::{DType, Nullability};

use crate::{and, col, eq, gt, gt_eq, lt, lt_eq, not_eq, or, test_harness, VortexExpr};

#[test]
fn dtype() {
let dtype = test_harness::struct_dtype();
let bool1: Arc<dyn VortexExpr> = col("bool1");
let bool2: Arc<dyn VortexExpr> = col("bool2");
assert_eq!(
and(bool1.clone(), bool2.clone())
.return_dtype(&dtype)
.unwrap(),
DType::Bool(Nullability::NonNullable)
);
assert_eq!(
or(bool1.clone(), bool2.clone())
.return_dtype(&dtype)
.unwrap(),
DType::Bool(Nullability::NonNullable)
);

let col1: Arc<dyn VortexExpr> = col("col1");
let col2: Arc<dyn VortexExpr> = col("col2");

assert_eq!(
eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
DType::Bool(Nullability::Nullable)
);
assert_eq!(
not_eq(col1.clone(), col2.clone())
.return_dtype(&dtype)
.unwrap(),
DType::Bool(Nullability::Nullable)
);
assert_eq!(
gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
DType::Bool(Nullability::Nullable)
);
assert_eq!(
gt_eq(col1.clone(), col2.clone())
.return_dtype(&dtype)
.unwrap(),
DType::Bool(Nullability::Nullable)
);
assert_eq!(
lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
DType::Bool(Nullability::Nullable)
);
assert_eq!(
lt_eq(col1.clone(), col2.clone())
.return_dtype(&dtype)
.unwrap(),
DType::Bool(Nullability::Nullable)
);

assert_eq!(
or(
lt(col1.clone(), col2.clone()),
not_eq(col1.clone(), col2.clone())
)
.return_dtype(&dtype)
.unwrap(),
DType::Bool(Nullability::Nullable)
);
}
}
24 changes: 23 additions & 1 deletion vortex-expr/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ impl VortexExpr for Column {
fn as_any(&self) -> &dyn Any {
self
}
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {

fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
batch
.clone()
.as_struct_array()
.ok_or_else(|| {
vortex_err!(
Expand All @@ -80,3 +82,23 @@ impl VortexExpr for Column {
self
}
}

#[cfg(test)]
mod tests {
use vortex_dtype::{DType, Nullability, PType};

use crate::{col, test_harness};

#[test]
fn dtype() {
let dtype = test_harness::struct_dtype();
assert_eq!(
col("a").return_dtype(&dtype).unwrap(),
DType::Primitive(PType::I32, Nullability::NonNullable)
);
assert_eq!(
col(1).return_dtype(&dtype).unwrap(),
DType::Primitive(PType::U16, Nullability::Nullable)
);
}
}
3 changes: 2 additions & 1 deletion vortex-expr/src/get_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ impl VortexExpr for GetItem {
fn as_any(&self) -> &dyn Any {
self
}
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {

fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
let child = self.child.evaluate(batch)?;
child
.as_struct_array()
Expand Down
14 changes: 13 additions & 1 deletion vortex-expr/src/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl VortexExpr for Identity {
self
}

fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
Ok(batch.clone())
}

Expand All @@ -45,3 +45,15 @@ impl VortexExpr for Identity {
pub fn ident() -> ExprRef {
Identity::new_expr()
}

#[cfg(test)]
mod tests {
use crate::{ident, test_harness};

#[test]
fn dtype() {
let dtype = test_harness::struct_dtype();
assert_eq!(ident().return_dtype(&dtype).unwrap(), dtype);
assert_eq!(ident().return_dtype(&dtype).unwrap(), dtype);
}
}
53 changes: 50 additions & 3 deletions vortex-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ pub use project::*;
pub use row_filter::*;
pub use select::*;
use vortex_array::aliases::hash_set::HashSet;
use vortex_array::ArrayData;
use vortex_dtype::Field;
use vortex_array::{ArrayDType as _, ArrayData, Canonical, IntoArrayData as _};
use vortex_dtype::{DType, Field};
use vortex_error::{VortexResult, VortexUnwrap};

use crate::traversal::{Node, ReferenceCollector};
Expand All @@ -49,11 +49,30 @@ pub trait VortexExpr: Debug + Send + Sync + DynEq + DynHash + Display {
fn as_any(&self) -> &dyn Any;

/// Compute result of expression on given batch producing a new batch
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData>;
///
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
let result = self.unchecked_evaluate(batch)?;
debug_assert_eq!(result.dtype(), &self.return_dtype(batch.dtype())?);
Ok(result)
}

/// Compute result of expression on given batch producing a new batch
///
/// "Unchecked" means that this function lacks a debug assertion that the returned array matches
/// the [VortexExpr::return_dtype] method. Use instead the [VortexExpr::evaluate] function which
/// includes such an assertion.
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData>;

fn children(&self) -> Vec<&ExprRef>;

fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;

/// Compute the type of the array returned by [VortexExpr::evaluate].
fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
let empty = Canonical::empty(scope_dtype)?.into_array();
self.unchecked_evaluate(&empty)
.map(|array| array.into_dtype())
}
}

pub trait VortexExprExt {
Expand Down Expand Up @@ -112,6 +131,34 @@ impl Eq for dyn VortexExpr {}

dyn_hash::hash_trait_object!(VortexExpr);

#[cfg(feature = "test-harness")]
pub mod test_harness {
use vortex_dtype::{DType, Nullability, PType, StructDType};

pub fn struct_dtype() -> DType {
DType::Struct(
StructDType::new(
[
"a".into(),
"col1".into(),
"col2".into(),
"bool1".into(),
"bool2".into(),
]
.into(),
vec![
DType::Primitive(PType::I32, Nullability::NonNullable),
DType::Primitive(PType::U16, Nullability::Nullable),
DType::Primitive(PType::U16, Nullability::Nullable),
DType::Bool(Nullability::NonNullable),
DType::Bool(Nullability::NonNullable),
],
),
Nullability::NonNullable,
)
}
}

#[cfg(test)]
mod tests {
use vortex_dtype::{DType, Field, Nullability, PType, StructDType};
Expand Down
15 changes: 13 additions & 2 deletions vortex-expr/src/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl VortexExpr for Like {
self
}

fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
let child = self.child().evaluate(batch)?;
let pattern = self.pattern().evaluate(batch)?;
like(
Expand Down Expand Up @@ -102,8 +102,9 @@ impl PartialEq for Like {
mod tests {
use vortex_array::array::BoolArray;
use vortex_array::IntoArrayVariant;
use vortex_dtype::{DType, Nullability};

use crate::{ident, not};
use crate::{ident, lit, not, Like};

#[test]
fn invert_booleans() {
Expand All @@ -121,4 +122,14 @@ mod tests {
vec![true, false, true, true, false, false]
);
}

#[test]
fn dtype() {
let dtype = DType::Utf8(Nullability::NonNullable);
let like_expr = Like::new_expr(ident(), lit("%test%"), false, false);
assert_eq!(
like_expr.return_dtype(&dtype).unwrap(),
DType::Bool(Nullability::NonNullable)
);
}
}
64 changes: 63 additions & 1 deletion vortex-expr/src/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl VortexExpr for Literal {
self
}

fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
fn unchecked_evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
Ok(ConstantArray::new(self.value.clone(), batch.len()).into_array())
}

Expand Down Expand Up @@ -72,3 +72,65 @@ impl VortexExpr for Literal {
pub fn lit(value: impl Into<Scalar>) -> ExprRef {
Literal::new_expr(value.into())
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use vortex_dtype::{DType, Nullability, PType, StructDType};
use vortex_scalar::Scalar;

use crate::{lit, test_harness};

#[test]
fn dtype() {
let dtype = test_harness::struct_dtype();

assert_eq!(
lit(10).return_dtype(&dtype).unwrap(),
DType::Primitive(PType::I32, Nullability::NonNullable)
);
assert_eq!(
lit(0_u8).return_dtype(&dtype).unwrap(),
DType::Primitive(PType::U8, Nullability::NonNullable)
);
assert_eq!(
lit(0.0_f32).return_dtype(&dtype).unwrap(),
DType::Primitive(PType::F32, Nullability::NonNullable)
);
assert_eq!(
lit(i64::MAX).return_dtype(&dtype).unwrap(),
DType::Primitive(PType::I64, Nullability::NonNullable)
);
assert_eq!(
lit(true).return_dtype(&dtype).unwrap(),
DType::Bool(Nullability::NonNullable)
);
assert_eq!(
lit(Scalar::null(DType::Bool(Nullability::Nullable)))
.return_dtype(&dtype)
.unwrap(),
DType::Bool(Nullability::Nullable)
);

let sdtype = DType::Struct(
StructDType::new(
Arc::from([Arc::from("dog"), Arc::from("cat")]),
vec![
DType::Primitive(PType::U32, Nullability::NonNullable),
DType::Utf8(Nullability::NonNullable),
],
),
Nullability::NonNullable,
);
assert_eq!(
lit(Scalar::struct_(
sdtype.clone(),
vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
))
.return_dtype(&dtype)
.unwrap(),
sdtype
);
}
}
Loading

0 comments on commit f97c0cd

Please sign in to comment.