From 9ef11072114d17392e84dc2102e57e61962818d5 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 28 Jan 2025 19:33:37 +0000 Subject: [PATCH] Refactor Vortex Mask (#2101) We now explicitly expose an AllTrue, AllFalse, and Values (mixed) variant. --- Cargo.lock | 1 + encodings/alp/src/alp/array.rs | 5 +- encodings/alp/src/alp_rd/array.rs | 11 +- encodings/bytebool/Cargo.toml | 1 + encodings/bytebool/src/array.rs | 5 +- encodings/bytebool/src/compute.rs | 19 +- encodings/datetime-parts/src/array.rs | 13 +- encodings/dict/src/array.rs | 8 +- .../fastlanes/src/bitpacking/compress.rs | 1 - .../src/bitpacking/compute/filter.rs | 34 +- encodings/fastlanes/src/bitpacking/mod.rs | 5 +- encodings/fastlanes/src/delta/mod.rs | 5 +- encodings/fastlanes/src/for/compress.rs | 11 +- encodings/fastlanes/src/for/mod.rs | 5 +- encodings/fsst/src/array.rs | 5 +- encodings/runend/benches/run_end_filter.rs | 5 +- encodings/runend/src/array.rs | 16 +- encodings/runend/src/compress.rs | 19 +- encodings/runend/src/compute/filter.rs | 59 +- encodings/runend/src/statistics.rs | 19 +- encodings/sparse/src/lib.rs | 22 +- encodings/zigzag/src/array.rs | 5 +- .../src/array/bool/compute/fill_forward.rs | 56 +- vortex-array/src/array/bool/compute/filter.rs | 29 +- vortex-array/src/array/bool/mod.rs | 5 +- vortex-array/src/array/bool/stats.rs | 11 +- vortex-array/src/array/chunked/canonical.rs | 4 +- .../src/array/chunked/compute/filter.rs | 40 +- vortex-array/src/array/chunked/mod.rs | 5 +- .../src/array/constant/compute/mod.rs | 4 +- vortex-array/src/array/constant/mod.rs | 9 +- vortex-array/src/array/extension/mod.rs | 5 +- vortex-array/src/array/list/mod.rs | 5 +- vortex-array/src/array/null/compute.rs | 7 +- vortex-array/src/array/null/mod.rs | 7 +- .../src/array/primitive/compute/cast.rs | 2 +- .../src/array/primitive/compute/fill.rs | 69 ++- .../src/array/primitive/compute/fill_null.rs | 8 +- .../src/array/primitive/compute/filter.rs | 37 +- vortex-array/src/array/primitive/mod.rs | 5 +- vortex-array/src/array/primitive/stats.rs | 11 +- vortex-array/src/array/struct_/compute.rs | 4 +- vortex-array/src/array/struct_/mod.rs | 5 +- vortex-array/src/array/varbin/array.rs | 5 +- vortex-array/src/array/varbin/arrow.rs | 5 +- .../src/array/varbin/compute/filter.rs | 55 +- vortex-array/src/array/varbin/compute/take.rs | 2 +- vortex-array/src/array/varbin/stats.rs | 2 +- vortex-array/src/array/varbinview/mod.rs | 8 +- vortex-array/src/canonical.rs | 10 +- vortex-array/src/compute/compare.rs | 3 +- vortex-array/src/compute/filter.rs | 25 +- vortex-array/src/data/mod.rs | 5 +- vortex-array/src/encoding/opaque.rs | 5 +- vortex-array/src/patches.rs | 41 +- vortex-array/src/validity.rs | 155 ++--- .../src/layouts/chunked/stats_table.rs | 2 +- vortex-mask/src/bitand.rs | 31 +- vortex-mask/src/eq.rs | 24 +- vortex-mask/src/intersect_by_rank.rs | 41 +- vortex-mask/src/iter_bools.rs | 130 +---- vortex-mask/src/lib.rs | 533 ++++++++++-------- .../src/compressors/for.rs | 2 +- vortex-scan/src/range_scan.rs | 2 +- vortex-scan/src/row_mask.rs | 122 +--- 65 files changed, 788 insertions(+), 1022 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4d908f3fed..5763a5ebcd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5141,6 +5141,7 @@ dependencies = [ "vortex-buffer", "vortex-dtype", "vortex-error", + "vortex-mask", "vortex-scalar", ] diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index 8db12e82fd..d2e5d7495b 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -6,7 +6,7 @@ use vortex_array::encoding::ids; use vortex_array::patches::{Patches, PatchesMetadata}; use vortex_array::stats::StatisticsVTable; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, ValidityVTable}; use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -15,6 +15,7 @@ use vortex_array::{ }; use vortex_dtype::{DType, PType}; use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult}; +use vortex_mask::Mask; use crate::alp::{alp_encode, decompress, Exponents}; @@ -124,7 +125,7 @@ impl ValidityVTable for ALPEncoding { array.encoded().is_valid(index) } - fn logical_validity(&self, array: &ALPArray) -> VortexResult { + fn logical_validity(&self, array: &ALPArray) -> VortexResult { array.encoded().logical_validity() } } diff --git a/encodings/alp/src/alp_rd/array.rs b/encodings/alp/src/alp_rd/array.rs index 25af414960..593de3c3ce 100644 --- a/encodings/alp/src/alp_rd/array.rs +++ b/encodings/alp/src/alp_rd/array.rs @@ -6,13 +6,14 @@ use vortex_array::encoding::ids; use vortex_array::patches::{Patches, PatchesMetadata}; use vortex_array::stats::{StatisticsVTable, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, Validity, ValidityVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ impl_encoding, ArrayDType, ArrayData, ArrayLen, Canonical, IntoCanonical, SerdeMetadata, }; use vortex_dtype::{DType, Nullability, PType}; use vortex_error::{vortex_bail, VortexExpect, VortexResult}; +use vortex_mask::Mask; use crate::alp_rd::alp_rd_decode; @@ -210,8 +211,7 @@ impl IntoCanonical for ALPRDArray { right_parts.into_buffer_mut::(), self.left_parts_patches(), )?, - self.logical_validity()? - .into_validity(self.dtype().nullability()), + Validity::from_mask(self.logical_validity()?, self.dtype().nullability()), ) } else { PrimitiveArray::new( @@ -222,8 +222,7 @@ impl IntoCanonical for ALPRDArray { right_parts.into_buffer_mut::(), self.left_parts_patches(), )?, - self.logical_validity()? - .into_validity(self.dtype().nullability()), + Validity::from_mask(self.logical_validity()?, self.dtype().nullability()), ) }; @@ -237,7 +236,7 @@ impl ValidityVTable for ALPRDEncoding { array.left_parts().is_valid(index) } - fn logical_validity(&self, array: &ALPRDArray) -> VortexResult { + fn logical_validity(&self, array: &ALPRDArray) -> VortexResult { // Use validity from left_parts array.left_parts().logical_validity() } diff --git a/encodings/bytebool/Cargo.toml b/encodings/bytebool/Cargo.toml index a82d47afc8..425ca74d19 100644 --- a/encodings/bytebool/Cargo.toml +++ b/encodings/bytebool/Cargo.toml @@ -24,6 +24,7 @@ vortex-array = { workspace = true } vortex-buffer = { workspace = true } vortex-dtype = { workspace = true } vortex-error = { workspace = true } +vortex-mask = { workspace = true } vortex-scalar = { workspace = true } [dev-dependencies] diff --git a/encodings/bytebool/src/array.rs b/encodings/bytebool/src/array.rs index f964884503..dbc210b7b9 100644 --- a/encodings/bytebool/src/array.rs +++ b/encodings/bytebool/src/array.rs @@ -6,13 +6,14 @@ use vortex_array::array::BoolArray; use vortex_array::encoding::ids; use vortex_array::stats::StatsSet; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use vortex_array::validity::{Validity, ValidityMetadata, ValidityVTable}; use vortex_array::variants::{BoolArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{impl_encoding, ArrayLen, Canonical, IntoCanonical, SerdeMetadata}; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_error::{VortexExpect as _, VortexResult}; +use vortex_mask::Mask; impl_encoding!( "vortex.bytebool", @@ -116,7 +117,7 @@ impl ValidityVTable for ByteBoolEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &ByteBoolArray) -> VortexResult { + fn logical_validity(&self, array: &ByteBoolArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index c1381e93b9..3f75648524 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -1,10 +1,11 @@ use num_traits::AsPrimitive; use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn}; -use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity}; +use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{match_each_integer_ptype, Nullability}; use vortex_error::{vortex_err, VortexResult}; +use vortex_mask::Mask; use vortex_scalar::Scalar; use super::{ByteBoolArray, ByteBoolEncoding}; @@ -55,7 +56,7 @@ impl TakeFn for ByteBoolEncoding { // FIXME(ngates): we should be operating over canonical validity, which doesn't // have fallible is_valid function. let arr = match validity { - LogicalValidity::AllValid(_) => { + Mask::AllTrue(_) => { let bools = match_each_integer_ptype!(indices.ptype(), |$I| { indices.as_slice::<$I>() .iter() @@ -68,16 +69,14 @@ impl TakeFn for ByteBoolEncoding { ByteBoolArray::from(bools).into_array() } - LogicalValidity::AllInvalid(_) => { - ByteBoolArray::from(vec![None; indices.len()]).into_array() - } - LogicalValidity::Mask(mask) => { + Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(), + Mask::Values(values) => { let bools = match_each_integer_ptype!(indices.ptype(), |$I| { indices.as_slice::<$I>() .iter() .map(|&idx| { let idx = idx.as_(); - if mask.value(idx) { + if values.value(idx) { Some(bools[idx]) } else { None @@ -101,13 +100,13 @@ impl FillForwardFn for ByteBoolEncoding { return Ok(array.to_array()); } // all valid, but we need to convert to non-nullable - if validity.all_valid() { + if validity.all_true() { return Ok( ByteBoolArray::try_new(array.buffer().clone(), Validity::AllValid)?.into_array(), ); } // all invalid => fill with default value (false) - if validity.all_invalid() { + if validity.all_false() { return Ok( ByteBoolArray::try_from_vec(vec![false; array.len()], Validity::AllValid)? .into_array(), @@ -115,7 +114,7 @@ impl FillForwardFn for ByteBoolEncoding { } let validity = validity - .to_null_buffer()? + .to_null_buffer() .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?; let bools = array.as_slice(); diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index 2e95de8ef3..4ab946966b 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -6,12 +6,13 @@ use vortex_array::compute::try_cast; use vortex_array::encoding::ids; use vortex_array::stats::StatsSet; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, Validity, ValidityVTable}; use vortex_array::variants::{ExtensionArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{impl_encoding, ArrayDType, ArrayData, ArrayLen, IntoArrayData, SerdeMetadata}; use vortex_dtype::{DType, PType}; use vortex_error::{vortex_bail, VortexExpect as _, VortexResult, VortexUnwrap}; +use vortex_mask::Mask; impl_encoding!( "vortex.datetimeparts", @@ -100,10 +101,10 @@ impl DateTimePartsArray { pub fn validity(&self) -> VortexResult { // FIXME(ngates): this function is weird... can we just use logical validity? - Ok(self - .days() - .logical_validity()? - .into_validity(self.dtype().nullability())) + Ok(Validity::from_mask( + self.days().logical_validity()?, + self.dtype().nullability(), + )) } } @@ -140,7 +141,7 @@ impl ValidityVTable for DateTimePartsEncoding { array.days().is_valid(index) } - fn logical_validity(&self, array: &DateTimePartsArray) -> VortexResult { + fn logical_validity(&self, array: &DateTimePartsArray) -> VortexResult { array.days().logical_validity() } } diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index 25eb34d917..ee844833a9 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -6,7 +6,7 @@ use vortex_array::compute::{scalar_at, take}; use vortex_array::encoding::ids; use vortex_array::stats::StatsSet; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, ValidityVTable}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -90,7 +90,7 @@ impl ValidityVTable for DictEncoding { array.values().is_valid(values_index) } - fn logical_validity(&self, array: &DictArray) -> VortexResult { + fn logical_validity(&self, array: &DictArray) -> VortexResult { if array.dtype().is_nullable() { let primitive_codes = array.codes().into_primitive()?; match_each_integer_ptype!(primitive_codes.ptype(), |$P| { @@ -99,10 +99,10 @@ impl ValidityVTable for DictEncoding { let is_valid_buffer = BooleanBuffer::collect_bool(is_valid.len(), |idx| { is_valid[idx] != 0 }); - Ok(LogicalValidity::Mask(Mask::from_buffer(is_valid_buffer))) + Ok(Mask::from_buffer(is_valid_buffer)) }) } else { - Ok(LogicalValidity::AllValid(array.len())) + Ok(Mask::AllTrue(array.len())) } } } diff --git a/encodings/fastlanes/src/bitpacking/compress.rs b/encodings/fastlanes/src/bitpacking/compress.rs index 593d97effd..b9232efedd 100644 --- a/encodings/fastlanes/src/bitpacking/compress.rs +++ b/encodings/fastlanes/src/bitpacking/compress.rs @@ -413,7 +413,6 @@ mod test { .unwrap() .to_null_buffer() .unwrap() - .unwrap() .into_inner() .set_indices() .collect::>() diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index 587f8fa307..b666f1af6d 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -6,8 +6,8 @@ use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_buffer::{Buffer, BufferMut}; use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType}; -use vortex_error::VortexResult; -use vortex_mask::{Mask, MaskIter}; +use vortex_error::{VortexExpect, VortexResult}; +use vortex_mask::Mask; use super::chunked_indices; use crate::bitpacking::compute::take::UNPACK_CHUNK_THRESHOLD; @@ -43,17 +43,20 @@ fn filter_primitive( .flatten(); // Short-circuit if the selectivity is high enough. - if mask.selectivity() > 0.8 { + if mask.density() > 0.8 { return filter(array.clone().into_primitive()?.as_ref(), mask) .and_then(|a| a.into_primitive()); } - let values: Buffer = match mask.iter() { - MaskIter::Indices(indices) => { - filter_indices(array, mask.true_count(), indices.iter().copied()) - } - MaskIter::Slices(slices) => filter_slices(array, mask.true_count(), slices.iter().copied()), - }; + let values: Buffer = filter_indices( + array, + mask.true_count(), + mask.values() + .vortex_expect("AllTrue and AllFalse handled by filter fn") + .indices() + .iter() + .copied(), + ); let mut values = PrimitiveArray::new(values, validity).reinterpret_cast(array.ptype()); if let Some(patches) = patches { @@ -111,19 +114,6 @@ fn filter_indices( values.freeze() } -fn filter_slices( - array: &BitPackedArray, - indices_len: usize, - slices: impl Iterator, -) -> Buffer { - // TODO(ngates): do this more efficiently. - filter_indices( - array, - indices_len, - slices.into_iter().flat_map(|(start, end)| start..end), - ) -} - #[cfg(test)] mod test { use vortex_array::array::PrimitiveArray; diff --git a/encodings/fastlanes/src/bitpacking/mod.rs b/encodings/fastlanes/src/bitpacking/mod.rs index 648dc272e1..a7a9012870 100644 --- a/encodings/fastlanes/src/bitpacking/mod.rs +++ b/encodings/fastlanes/src/bitpacking/mod.rs @@ -7,7 +7,7 @@ use vortex_array::encoding::ids; use vortex_array::patches::{Patches, PatchesMetadata}; use vortex_array::stats::{StatisticsVTable, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use vortex_array::validity::{Validity, ValidityMetadata, ValidityVTable}; use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -16,6 +16,7 @@ use vortex_array::{ use vortex_buffer::ByteBuffer; use vortex_dtype::{DType, NativePType, PType}; use vortex_error::{vortex_bail, vortex_err, VortexExpect as _, VortexResult}; +use vortex_mask::Mask; mod compress; mod compute; @@ -262,7 +263,7 @@ impl ValidityVTable for BitPackedEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &BitPackedArray) -> VortexResult { + fn logical_validity(&self, array: &BitPackedArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/encodings/fastlanes/src/delta/mod.rs b/encodings/fastlanes/src/delta/mod.rs index eac50e7459..4016d0608b 100644 --- a/encodings/fastlanes/src/delta/mod.rs +++ b/encodings/fastlanes/src/delta/mod.rs @@ -5,7 +5,7 @@ use vortex_array::array::PrimitiveArray; use vortex_array::encoding::ids; use vortex_array::stats::{StatisticsVTable, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use vortex_array::validity::{Validity, ValidityMetadata, ValidityVTable}; use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -15,6 +15,7 @@ use vortex_array::{ use vortex_buffer::Buffer; use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType}; use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult}; +use vortex_mask::Mask; mod compress; mod compute; @@ -243,7 +244,7 @@ impl ValidityVTable for DeltaEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &DeltaArray) -> VortexResult { + fn logical_validity(&self, array: &DeltaArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/encodings/fastlanes/src/for/compress.rs b/encodings/fastlanes/src/for/compress.rs index df4fc80d24..b0ca126d00 100644 --- a/encodings/fastlanes/src/for/compress.rs +++ b/encodings/fastlanes/src/for/compress.rs @@ -1,7 +1,7 @@ use num_traits::{PrimInt, WrappingAdd, WrappingSub}; use vortex_array::array::{ConstantArray, PrimitiveArray}; use vortex_array::stats::{trailing_zeros, ArrayStatistics, Stat}; -use vortex_array::validity::{ArrayValidity, LogicalValidity}; +use vortex_array::validity::ArrayValidity; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_buffer::{Buffer, BufferMut}; @@ -9,6 +9,7 @@ use vortex_dtype::{ match_each_integer_ptype, match_each_unsigned_integer_ptype, DType, NativePType, Nullability, }; use vortex_error::{vortex_err, VortexExpect, VortexResult, VortexUnwrap}; +use vortex_mask::Mask; use vortex_scalar::Scalar; use vortex_sparse::SparseArray; @@ -39,20 +40,20 @@ pub fn for_compress(array: PrimitiveArray) -> VortexResult { } fn encoded_zero( - logical_validity: LogicalValidity, + logical_validity: Mask, nullability: Nullability, ) -> VortexResult { let encoded_ptype = T::PTYPE.to_unsigned(); let zero = match_each_unsigned_integer_ptype!(encoded_ptype, |$T| Scalar::primitive($T::default(), nullability)); Ok(match logical_validity { - LogicalValidity::AllValid(len) => ConstantArray::new(zero, len).into_array(), - LogicalValidity::AllInvalid(len) => ConstantArray::new( + Mask::AllTrue(len) => ConstantArray::new(zero, len).into_array(), + Mask::AllFalse(len) => ConstantArray::new( Scalar::null(DType::Primitive(encoded_ptype, Nullability::Nullable)), len, ) .into_array(), - LogicalValidity::Mask(mask) => { + Mask::Values(mask) => { let len = mask.len(); let valid_indices = mask diff --git a/encodings/fastlanes/src/for/mod.rs b/encodings/fastlanes/src/for/mod.rs index 5b71188102..5b2f80f539 100644 --- a/encodings/fastlanes/src/for/mod.rs +++ b/encodings/fastlanes/src/for/mod.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use vortex_array::encoding::ids; use vortex_array::stats::{StatisticsVTable, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, ValidityVTable}; use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -13,6 +13,7 @@ use vortex_array::{ }; use vortex_dtype::DType; use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; +use vortex_mask::Mask; use vortex_scalar::{PValue, Scalar}; mod compress; @@ -95,7 +96,7 @@ impl ValidityVTable for FoREncoding { array.encoded().is_valid(index) } - fn logical_validity(&self, array: &FoRArray) -> VortexResult { + fn logical_validity(&self, array: &FoRArray) -> VortexResult { array.encoded().logical_validity() } } diff --git a/encodings/fsst/src/array.rs b/encodings/fsst/src/array.rs index 447f0452ec..3374132d04 100644 --- a/encodings/fsst/src/array.rs +++ b/encodings/fsst/src/array.rs @@ -4,12 +4,13 @@ use vortex_array::array::{VarBinArray, VarBinEncoding}; use vortex_array::encoding::{ids, Encoding}; use vortex_array::stats::{StatisticsVTable, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, Validity, ValidityVTable}; use vortex_array::variants::{BinaryArrayTrait, Utf8ArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{impl_encoding, ArrayDType, ArrayData, ArrayLen, IntoCanonical, SerdeMetadata}; use vortex_dtype::{DType, Nullability, PType}; use vortex_error::{vortex_bail, VortexExpect, VortexResult}; +use vortex_mask::Mask; impl_encoding!("vortex.fsst", ids::FSST, FSST, SerdeMetadata); @@ -198,7 +199,7 @@ impl ValidityVTable for FSSTEncoding { array.codes().is_valid(index) } - fn logical_validity(&self, array: &FSSTArray) -> VortexResult { + fn logical_validity(&self, array: &FSSTArray) -> VortexResult { array.codes().logical_validity() } } diff --git a/encodings/runend/benches/run_end_filter.rs b/encodings/runend/benches/run_end_filter.rs index 5914861a3f..9961e81157 100644 --- a/encodings/runend/benches/run_end_filter.rs +++ b/encodings/runend/benches/run_end_filter.rs @@ -67,7 +67,10 @@ fn evenly_spaced(c: &mut Criterion) { ), |b| { b.iter(|| { - black_box(take_indices_unchecked(&array, mask.indices()).unwrap()) + black_box( + take_indices_unchecked(&array, mask.values().unwrap().indices()) + .unwrap(), + ) }); }, ); diff --git a/encodings/runend/src/array.rs b/encodings/runend/src/array.rs index 91952eec8f..77f9bbd9b7 100644 --- a/encodings/runend/src/array.rs +++ b/encodings/runend/src/array.rs @@ -8,7 +8,7 @@ use vortex_array::compute::{ use vortex_array::encoding::ids; use vortex_array::stats::{ArrayStatistics, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, ValidityVTable}; use vortex_array::variants::{BoolArrayTrait, PrimitiveArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -181,22 +181,20 @@ impl ValidityVTable for RunEndEncoding { array.values().is_valid(physical_idx) } - fn logical_validity(&self, array: &RunEndArray) -> VortexResult { + fn logical_validity(&self, array: &RunEndArray) -> VortexResult { Ok(match array.values().logical_validity()? { - LogicalValidity::AllValid(_) => LogicalValidity::AllValid(array.len()), - LogicalValidity::AllInvalid(_) => LogicalValidity::AllInvalid(array.len()), - LogicalValidity::Mask(validity) => { + Mask::AllTrue(_) => Mask::AllTrue(array.len()), + Mask::AllFalse(_) => Mask::AllFalse(array.len()), + Mask::Values(values) => { let ree_validity = RunEndArray::with_offset_and_length( array.ends(), - validity.into_array(), + values.into_array(), array.offset(), array.len(), ) .vortex_expect("invalid array") .into_array(); - LogicalValidity::Mask(Mask::from_buffer( - ree_validity.into_bool()?.boolean_buffer(), - )) + Mask::from_buffer(ree_validity.into_bool()?.boolean_buffer()) } }) } diff --git a/encodings/runend/src/compress.rs b/encodings/runend/src/compress.rs index 443cbaae95..e9bf2172f1 100644 --- a/encodings/runend/src/compress.rs +++ b/encodings/runend/src/compress.rs @@ -1,12 +1,13 @@ use arrow_buffer::BooleanBufferBuilder; use itertools::Itertools; use vortex_array::array::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray}; -use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity}; +use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; use vortex_buffer::{buffer, Buffer, BufferMut}; use vortex_dtype::{match_each_integer_ptype, match_each_native_ptype, NativePType, Nullability}; use vortex_error::{VortexExpect, VortexResult}; +use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::iter::trimmed_ends_iter; @@ -175,22 +176,22 @@ pub fn runend_decode_bools( pub fn runend_decode_typed_primitive( run_ends: impl Iterator, values: &[T], - values_validity: LogicalValidity, + values_validity: Mask, values_nullability: Nullability, length: usize, ) -> VortexResult { Ok(match values_validity { - LogicalValidity::AllValid(_) => { + Mask::AllTrue(_) => { let mut decoded: BufferMut = BufferMut::with_capacity(length); for (end, value) in run_ends.zip_eq(values) { decoded.push_n(*value, end - decoded.len()); } PrimitiveArray::new(decoded, values_nullability.into()) } - LogicalValidity::AllInvalid(_) => { + Mask::AllFalse(_) => { PrimitiveArray::new(buffer![T::default(); length], Validity::AllInvalid) } - LogicalValidity::Mask(mask) => { + Mask::Values(mask) => { let mut decoded = BufferMut::with_capacity(length); let mut decoded_validity = BooleanBufferBuilder::new(length); for (end, value) in run_ends.zip_eq( @@ -218,23 +219,23 @@ pub fn runend_decode_typed_primitive( pub fn runend_decode_typed_bool( run_ends: impl Iterator, values: BooleanBuffer, - values_validity: LogicalValidity, + values_validity: Mask, values_nullability: Nullability, length: usize, ) -> VortexResult { Ok(match values_validity { - LogicalValidity::AllValid(_) => { + Mask::AllTrue(_) => { let mut decoded = BooleanBufferBuilder::new(length); for (end, value) in run_ends.zip_eq(values.iter()) { decoded.append_n(end - decoded.len(), value); } BoolArray::new(decoded.finish(), values_nullability) } - LogicalValidity::AllInvalid(_) => { + Mask::AllFalse(_) => { BoolArray::try_new(BooleanBuffer::new_unset(length), Validity::AllInvalid) .vortex_expect("invalid array") } - LogicalValidity::Mask(mask) => { + Mask::Values(mask) => { let mut decoded = BooleanBufferBuilder::new(length); let mut decoded_validity = BooleanBufferBuilder::new(length); for (end, value) in run_ends.zip_eq( diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index 2c3da188fe..9bfd349405 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -7,10 +7,10 @@ use vortex_array::array::PrimitiveArray; use vortex_array::compute::{filter, FilterFn}; use vortex_array::validity::Validity; use vortex_array::variants::PrimitiveArrayTrait; -use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; +use vortex_array::{ArrayDType, ArrayData, ArrayLen, Canonical, IntoArrayData, IntoArrayVariant}; use vortex_buffer::buffer_mut; use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType}; -use vortex_error::{VortexResult, VortexUnwrap}; +use vortex_error::{VortexExpect, VortexResult, VortexUnwrap}; use vortex_mask::Mask; use crate::compute::take::take_indices_unchecked; @@ -20,21 +20,32 @@ const FILTER_TAKE_THRESHOLD: f64 = 0.1; impl FilterFn for RunEndEncoding { fn filter(&self, array: &RunEndArray, mask: &Mask) -> VortexResult { - let runs_ratio = mask.true_count() as f64 / array.ends().len() as f64; - - if runs_ratio < FILTER_TAKE_THRESHOLD || mask.true_count() < 25 { - // This strategy is directly proportional to the number of indices. - take_indices_unchecked(array, mask.indices()) - } else { - // This strategy ends up being close to fixed cost based on the number of runs, - // rather than the number of indices. - let primitive_run_ends = array.ends().into_primitive()?; - let (run_ends, values_mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| { - filter_run_end_primitive(primitive_run_ends.as_slice::<$P>(), array.offset() as u64, array.len() as u64, mask)? - }); - let values = filter(&array.values(), &values_mask)?; - - RunEndArray::try_new(run_ends.into_array(), values).map(|a| a.into_array()) + match mask { + Mask::AllTrue(_) => Ok(array.clone().into_array()), + Mask::AllFalse(_) => Ok(Canonical::empty(array.dtype()).into()), + Mask::Values(mask_values) => { + let runs_ratio = mask_values.true_count() as f64 / array.ends().len() as f64; + + if runs_ratio < FILTER_TAKE_THRESHOLD || mask_values.true_count() < 25 { + // This strategy is directly proportional to the number of indices. + take_indices_unchecked(array, mask_values.indices()) + } else { + // This strategy ends up being close to fixed cost based on the number of runs, + // rather than the number of indices. + let primitive_run_ends = array.ends().into_primitive()?; + let (run_ends, values_mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| { + filter_run_end_primitive( + primitive_run_ends.as_slice::<$P>(), + array.offset() as u64, + array.len() as u64, + mask_values.boolean_buffer(), + )? + }); + let values = filter(&array.values(), &values_mask)?; + + RunEndArray::try_new(run_ends.into_array(), values).map(|a| a.into_array()) + } + } } } } @@ -43,7 +54,12 @@ impl FilterFn for RunEndEncoding { pub fn filter_run_end(array: &RunEndArray, mask: &Mask) -> VortexResult { let primitive_run_ends = array.ends().into_primitive()?; let (run_ends, values_mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| { - filter_run_end_primitive(primitive_run_ends.as_slice::<$P>(), array.offset() as u64, array.len() as u64, mask)? + filter_run_end_primitive( + primitive_run_ends.as_slice::<$P>(), + array.offset() as u64, + array.len() as u64, + mask.values().vortex_expect("AllTrue and AllFalse handled by filter fn").boolean_buffer(), + )? }); let values = filter(&array.values(), &values_mask)?; @@ -55,22 +71,21 @@ fn filter_run_end_primitive + AsPrimitiv run_ends: &[R], offset: u64, length: u64, - mask: &Mask, + mask: &BooleanBuffer, ) -> VortexResult<(PrimitiveArray, Mask)> { let mut new_run_ends = buffer_mut![R::zero(); run_ends.len()]; let mut start = 0u64; let mut j = 0; let mut count = R::zero(); - let filter_values = mask.boolean_buffer(); let new_mask: Mask = BooleanBuffer::collect_bool(run_ends.len(), |i| { let mut keep = false; let end = min(run_ends[i].as_() - offset, length); // Safety: predicate must be the same length as the array the ends have been taken from - for pred in (start..end) - .map(|i| unsafe { filter_values.value_unchecked(i.try_into().vortex_unwrap()) }) + for pred in + (start..end).map(|i| unsafe { mask.value_unchecked(i.try_into().vortex_unwrap()) }) { count += >::from(pred); keep |= pred diff --git a/encodings/runend/src/statistics.rs b/encodings/runend/src/statistics.rs index 7609db8945..10c87c029f 100644 --- a/encodings/runend/src/statistics.rs +++ b/encodings/runend/src/statistics.rs @@ -3,11 +3,12 @@ use std::cmp; use arrow_buffer::BooleanBuffer; use itertools::Itertools; use vortex_array::stats::{ArrayStatistics as _, Stat, StatisticsVTable, StatsSet}; -use vortex_array::validity::{ArrayValidity as _, LogicalValidity}; +use vortex_array::validity::ArrayValidity as _; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType as _, ArrayLen as _, IntoArrayVariant as _}; use vortex_dtype::{match_each_unsigned_integer_ptype, DType, NativePType}; use vortex_error::VortexResult; +use vortex_mask::Mask; use vortex_scalar::ScalarValue; use crate::{RunEndArray, RunEndEncoding}; @@ -22,7 +23,7 @@ impl StatisticsVTable for RunEndEncoding { .statistics() .compute_is_sorted() .unwrap_or(false) - && array.logical_validity()?.all_valid(), + && array.logical_validity()?.all_true(), )), Stat::TrueCount => match array.dtype() { DType::Bool(_) => Some(ScalarValue::from(array.true_count()?)), @@ -54,7 +55,7 @@ impl RunEndArray { decompressed_values: BooleanBuffer, ) -> VortexResult { Ok(match self.values().logical_validity()? { - LogicalValidity::AllValid(_) => { + Mask::AllTrue(_) => { let mut begin = self.offset() as u64; decompressed_ends .iter() @@ -68,9 +69,9 @@ impl RunEndArray { }) .sum() } - LogicalValidity::AllInvalid(_) => 0, - LogicalValidity::Mask(mask) => { - let mut is_valid = mask.indices().iter(); + Mask::AllFalse(_) => 0, + Mask::Values(values) => { + let mut is_valid = values.indices().iter(); match is_valid.next() { None => self.len() as u64, Some(&valid_index) => { @@ -103,9 +104,9 @@ impl RunEndArray { fn null_count(&self) -> VortexResult { let ends = self.ends().into_primitive()?; let null_count = match self.values().logical_validity()? { - LogicalValidity::AllValid(_) => 0u64, - LogicalValidity::AllInvalid(_) => self.len() as u64, - LogicalValidity::Mask(mask) => { + Mask::AllTrue(_) => 0u64, + Mask::AllFalse(_) => self.len() as u64, + Mask::Values(mask) => { match_each_unsigned_integer_ptype!(ends.ptype(), |$P| self.null_count_with_array_validity(ends.as_slice::<$P>(), mask.boolean_buffer())) } }; diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index fd937fd8c7..228f71d14a 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -6,7 +6,7 @@ use vortex_array::encoding::ids; use vortex_array::patches::{Patches, PatchesMetadata}; use vortex_array::stats::{ArrayStatistics, Stat, StatisticsVTable, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, ValidityVTable}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -197,7 +197,7 @@ impl ValidityVTable for SparseEncoding { }) } - fn logical_validity(&self, array: &SparseArray) -> VortexResult { + fn logical_validity(&self, array: &SparseArray) -> VortexResult { let indices = array.patches().indices().clone().into_primitive()?; if array.fill_scalar().is_null() { @@ -212,7 +212,7 @@ impl ValidityVTable for SparseEncoding { }); }); - return Ok(LogicalValidity::Mask(Mask::from_buffer(buffer.finish()))); + return Ok(Mask::from_buffer(buffer.finish())); } // If the fill_value is non-null, then the validity is based on the validity of the @@ -226,11 +226,11 @@ impl ValidityVTable for SparseEncoding { .into_iter() .enumerate() .for_each(|(patch_idx, &index)| { - buffer.set_bit(index.try_into().vortex_expect("failed to cast to usize"), values_validity.is_valid(patch_idx)); + buffer.set_bit(index.try_into().vortex_expect("failed to cast to usize"), values_validity.value(patch_idx)); }) }); - Ok(LogicalValidity::Mask(Mask::from_buffer(buffer.finish()))) + Ok(Mask::from_buffer(buffer.finish())) } } @@ -350,11 +350,13 @@ mod test { #[test] pub fn sparse_logical_validity() { let array = sparse_array(nullable_fill()); - let LogicalValidity::Mask(mask) = array.logical_validity().unwrap() else { - unreachable!() - }; assert_eq!( - mask.boolean_buffer().iter().collect_vec(), + array + .logical_validity() + .unwrap() + .to_boolean_buffer() + .iter() + .collect_vec(), [false, false, true, false, false, true, false, false, true, false] ); } @@ -362,7 +364,7 @@ mod test { #[test] fn sparse_logical_validity_non_null_fill() { let array = sparse_array(non_nullable_fill()); - assert!(array.logical_validity().unwrap().all_valid()); + assert!(array.logical_validity().unwrap().all_true()); } #[test] diff --git a/encodings/zigzag/src/array.rs b/encodings/zigzag/src/array.rs index 57c2c58a85..ed3172749c 100644 --- a/encodings/zigzag/src/array.rs +++ b/encodings/zigzag/src/array.rs @@ -2,7 +2,7 @@ use vortex_array::array::PrimitiveArray; use vortex_array::encoding::ids; use vortex_array::stats::{ArrayStatistics, Stat, StatisticsVTable, StatsSet}; use vortex_array::validate::ValidateVTable; -use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, ValidityVTable}; use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -11,6 +11,7 @@ use vortex_array::{ }; use vortex_dtype::{DType, PType}; use vortex_error::{vortex_bail, vortex_err, vortex_panic, VortexExpect as _, VortexResult}; +use vortex_mask::Mask; use vortex_scalar::ScalarValue; use zigzag::ZigZag as ExternalZigZag; @@ -77,7 +78,7 @@ impl ValidityVTable for ZigZagEncoding { array.encoded().is_valid(index) } - fn logical_validity(&self, array: &ZigZagArray) -> VortexResult { + fn logical_validity(&self, array: &ZigZagArray) -> VortexResult { array.encoded().logical_validity() } } diff --git a/vortex-array/src/array/bool/compute/fill_forward.rs b/vortex-array/src/array/bool/compute/fill_forward.rs index 2e0f29bc8d..13fbe1756f 100644 --- a/vortex-array/src/array/bool/compute/fill_forward.rs +++ b/vortex-array/src/array/bool/compute/fill_forward.rs @@ -1,10 +1,12 @@ use arrow_buffer::BooleanBuffer; +use itertools::Itertools; use vortex_dtype::Nullability; use vortex_error::{vortex_err, VortexResult}; +use vortex_mask::AllOr; use crate::array::{BoolArray, BoolEncoding}; use crate::compute::FillForwardFn; -use crate::validity::{ArrayValidity, LogicalValidity, Validity}; +use crate::validity::{ArrayValidity, Validity}; use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, ToArrayData}; impl FillForwardFn for BoolEncoding { @@ -16,34 +18,32 @@ impl FillForwardFn for BoolEncoding { return Ok(array.to_array()); } - // all valid, but we need to convert to non-nullable - if validity.all_valid() { - return Ok(BoolArray::new(array.boolean_buffer(), Nullability::Nullable).into_array()); + match validity.boolean_buffer() { + AllOr::All => { + // all valid, but we need to convert to non-nullable + Ok(BoolArray::new(array.boolean_buffer(), Nullability::Nullable).into_array()) + } + AllOr::None => { + // all invalid => fill with default value (false) + Ok( + BoolArray::try_new(BooleanBuffer::new_unset(array.len()), Validity::AllValid)? + .into_array(), + ) + } + AllOr::Some(validity) => { + let bools = array.boolean_buffer(); + let mut last_value = false; + let buffer = BooleanBuffer::from_iter(bools.iter().zip_eq(validity.iter()).map( + |(v, valid)| { + if valid { + last_value = v; + } + last_value + }, + )); + Ok(BoolArray::try_new(buffer, Validity::AllValid)?.into_array()) + } } - // all invalid => fill with default value (false) - if validity.all_invalid() { - return Ok(BoolArray::try_new( - BooleanBuffer::new_unset(array.len()), - Validity::AllValid, - )? - .into_array()); - } - - let validity = validity - .to_null_buffer()? - .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?; - - let bools = array.boolean_buffer(); - let mut last_value = false; - let buffer = BooleanBuffer::from_iter(bools.iter().zip(validity.inner().iter()).map( - |(v, valid)| { - if valid { - last_value = v; - } - last_value - }, - )); - Ok(BoolArray::try_new(buffer, Validity::AllValid)?.into_array()) } } diff --git a/vortex-array/src/array/bool/compute/filter.rs b/vortex-array/src/array/bool/compute/filter.rs index f3955cf2ea..14483d60bb 100644 --- a/vortex-array/src/array/bool/compute/filter.rs +++ b/vortex-array/src/array/bool/compute/filter.rs @@ -1,17 +1,30 @@ +use std::sync::Arc; + use arrow_buffer::{bit_util, BooleanBuffer, BooleanBufferBuilder}; use vortex_error::{VortexExpect, VortexResult}; -use vortex_mask::{Mask, MaskIter}; +use vortex_mask::{AllOr, Mask, MaskIter, MaskValues}; use crate::array::{BoolArray, BoolEncoding}; use crate::compute::FilterFn; -use crate::{ArrayData, IntoArrayData}; +use crate::{ArrayDType, ArrayData, Canonical, IntoArrayData}; + +/// If the filter density is above 80%, we use slices to filter the array instead of indices. +const FILTER_SLICES_DENSITY_THRESHOLD: f64 = 0.8; impl FilterFn for BoolEncoding { fn filter(&self, array: &BoolArray, mask: &Mask) -> VortexResult { let validity = array.validity().filter(mask)?; - let buffer = match mask.iter() { - MaskIter::Indices(indices) => filter_indices_slice(&array.boolean_buffer(), indices), + let mask_values = mask + .values() + .vortex_expect("AllTrue and AllFalse are handled by filter fn"); + + let buffer = match mask_values.threshold_iter(FILTER_SLICES_DENSITY_THRESHOLD) { + MaskIter::Indices(indices) => filter_indices( + &array.boolean_buffer(), + mask.true_count(), + indices.iter().copied(), + ), MaskIter::Slices(slices) => filter_slices( &array.boolean_buffer(), mask.true_count(), @@ -26,14 +39,6 @@ impl FilterFn for BoolEncoding { /// Select indices from a boolean buffer. /// NOTE: it was benchmarked to be faster using collect_bool to index into a slice than to /// pass the indices as an iterator of usize. So we keep this alternate implementation. -fn filter_indices_slice(buffer: &BooleanBuffer, indices: &[usize]) -> BooleanBuffer { - let src = buffer.values().as_ptr(); - let offset = buffer.offset(); - BooleanBuffer::collect_bool(indices.len(), |idx| unsafe { - bit_util::get_bit_raw(src, *indices.get_unchecked(idx) + offset) - }) -} - pub fn filter_indices( buffer: &BooleanBuffer, indices_len: usize, diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index c13717d003..1a4462b9e9 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -9,7 +9,7 @@ use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; use crate::encoding::ids; use crate::stats::StatsSet; use crate::validate::ValidateVTable; -use crate::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use crate::validity::{Validity, ValidityMetadata, ValidityVTable}; use crate::variants::{BoolArrayTrait, VariantsVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ @@ -23,6 +23,7 @@ mod stats; // Re-export the BooleanBuffer type on our API surface. pub use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; +use vortex_mask::Mask; impl_encoding!("vortex.bool", ids::BOOL, Bool, RkyvMetadata); @@ -211,7 +212,7 @@ impl ValidityVTable for BoolEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &BoolArray) -> VortexResult { + fn logical_validity(&self, array: &BoolArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/vortex-array/src/array/bool/stats.rs b/vortex-array/src/array/bool/stats.rs index c2e67e44ee..6e584e460a 100644 --- a/vortex-array/src/array/bool/stats.rs +++ b/vortex-array/src/array/bool/stats.rs @@ -4,11 +4,12 @@ use arrow_buffer::BooleanBuffer; use itertools::Itertools; use vortex_dtype::{DType, Nullability}; use vortex_error::VortexResult; +use vortex_mask::Mask; use crate::array::{BoolArray, BoolEncoding}; use crate::nbytes::ArrayNBytes; use crate::stats::{Stat, StatisticsVTable, StatsSet}; -use crate::validity::{ArrayValidity, LogicalValidity}; +use crate::validity::ArrayValidity; use crate::{ArrayDType, ArrayLen, IntoArrayVariant}; impl StatisticsVTable for BoolEncoding { @@ -26,10 +27,10 @@ impl StatisticsVTable for BoolEncoding { } match array.logical_validity()? { - LogicalValidity::AllValid(_) => self.compute_statistics(&array.boolean_buffer(), stat), - LogicalValidity::AllInvalid(v) => Ok(StatsSet::nulls(v, array.dtype())), - LogicalValidity::Mask(mask) => self.compute_statistics( - &NullableBools(&array.boolean_buffer(), mask.boolean_buffer()), + Mask::AllTrue(_) => self.compute_statistics(&array.boolean_buffer(), stat), + Mask::AllFalse(v) => Ok(StatsSet::nulls(v, array.dtype())), + Mask::Values(values) => self.compute_statistics( + &NullableBools(&array.boolean_buffer(), values.boolean_buffer()), stat, ), } diff --git a/vortex-array/src/array/chunked/canonical.rs b/vortex-array/src/array/chunked/canonical.rs index fbf9845dd7..c512e31f03 100644 --- a/vortex-array/src/array/chunked/canonical.rs +++ b/vortex-array/src/array/chunked/canonical.rs @@ -18,9 +18,7 @@ use crate::{ impl IntoCanonical for ChunkedArray { fn into_canonical(self) -> VortexResult { - let validity = self - .logical_validity()? - .into_validity(self.dtype().nullability()); + let validity = Validity::from_mask(self.logical_validity()?, self.dtype().nullability()); try_canonicalize_chunks(self.chunks().collect(), validity, self.dtype()) } } diff --git a/vortex-array/src/array/chunked/compute/filter.rs b/vortex-array/src/array/chunked/compute/filter.rs index ef488d017f..a91fe6c4f3 100644 --- a/vortex-array/src/array/chunked/compute/filter.rs +++ b/vortex-array/src/array/chunked/compute/filter.rs @@ -1,29 +1,31 @@ +use std::sync::Arc; + use vortex_buffer::BufferMut; use vortex_error::{VortexExpect, VortexResult, VortexUnwrap}; -use vortex_mask::Mask; +use vortex_mask::{Mask, MaskIter, MaskValues}; use crate::array::{ChunkedArray, ChunkedEncoding, PrimitiveArray}; use crate::compute::{filter, take, FilterFn, SearchSorted, SearchSortedSide}; use crate::validity::Validity; -use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoCanonical}; +use crate::{ArrayDType, ArrayData, ArrayLen, Canonical, IntoArrayData, IntoCanonical}; // This is modeled after the constant with the equivalent name in arrow-rs. const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; impl FilterFn for ChunkedEncoding { fn filter(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { - let selected = mask.true_count(); + let mask_values = mask + .values() + .vortex_expect("AllTrue and AllFalse are handled by filter fn"); // Based on filter selectivity, we take the values between a range of slices, or // we take individual indices. - let selectivity = selected as f64 / array.len() as f64; - let chunks = if selectivity > FILTER_SLICES_SELECTIVITY_THRESHOLD { - filter_slices(array, mask) - } else { - filter_indices(array, mask) - }; + let chunks = match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { + MaskIter::Indices(indices) => filter_indices(array, indices.iter().copied()), + MaskIter::Slices(slices) => filter_slices(array, slices.iter().copied()), + }?; - Ok(ChunkedArray::try_new(chunks?, array.dtype().clone())?.into_array()) + Ok(ChunkedArray::try_new(chunks, array.dtype().clone())?.into_array()) } } @@ -39,7 +41,10 @@ enum ChunkFilter { } /// Filter the chunks using slice ranges. -fn filter_slices(array: &ChunkedArray, mask: &Mask) -> VortexResult> { +fn filter_slices( + array: &ChunkedArray, + slices: impl Iterator, +) -> VortexResult> { let mut result = Vec::with_capacity(array.nchunks()); // Pre-materialize the chunk ends for performance. @@ -49,8 +54,8 @@ fn filter_slices(array: &ChunkedArray, mask: &Mask) -> VortexResult VortexResult VortexResult> { +fn filter_indices( + array: &ChunkedArray, + indices: impl Iterator, +) -> VortexResult> { let mut result = Vec::with_capacity(array.nchunks()); let mut current_chunk_id = 0; let mut chunk_indices = BufferMut::with_capacity(array.nchunks()); @@ -125,8 +133,8 @@ fn filter_indices(array: &ChunkedArray, mask: &Mask) -> VortexResult(); - for set_index in mask.indices() { - let (chunk_id, index) = find_chunk_idx(*set_index, chunk_ends); + for set_index in indices { + let (chunk_id, index) = find_chunk_idx(set_index, chunk_ends); if chunk_id != current_chunk_id { // Push the chunk we've accumulated. if !chunk_indices.is_empty() { diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index 1afc1b6733..4a77816e3c 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; use vortex_buffer::BufferMut; use vortex_dtype::{DType, Nullability, PType}; use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult, VortexUnwrap}; +use vortex_mask::Mask; use crate::array::primitive::PrimitiveArray; use crate::compute::{scalar_at, search_sorted_usize, SearchSortedSide}; @@ -20,7 +21,7 @@ use crate::stats::StatsSet; use crate::stream::{ArrayStream, ArrayStreamAdapter}; use crate::validate::ValidateVTable; use crate::validity::Validity::NonNullable; -use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable}; +use crate::validity::{ArrayValidity, Validity, ValidityVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ impl_encoding, ArrayDType, ArrayData, ArrayLen, DeserializeMetadata, IntoArrayData, @@ -231,7 +232,7 @@ impl ValidityVTable for ChunkedEncoding { array.chunk(chunk)?.is_valid(offset_in_chunk) } - fn logical_validity(&self, array: &ChunkedArray) -> VortexResult { + fn logical_validity(&self, array: &ChunkedArray) -> VortexResult { // TODO(ngates): implement FromIterator for LogicalValidity. let validity: Validity = array.chunks().map(|a| a.logical_validity()).try_collect()?; validity.to_logical(array.len()) diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index 2df732397b..d4b76b3265 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -4,8 +4,10 @@ mod compare; mod invert; mod search_sorted; +use std::sync::Arc; + use vortex_error::VortexResult; -use vortex_mask::Mask; +use vortex_mask::{Mask, MaskValues}; use vortex_scalar::Scalar; use crate::array::constant::ConstantArray; diff --git a/vortex-array/src/array/constant/mod.rs b/vortex-array/src/array/constant/mod.rs index 3fd1284a97..845bce5d07 100644 --- a/vortex-array/src/array/constant/mod.rs +++ b/vortex-array/src/array/constant/mod.rs @@ -3,12 +3,13 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; use vortex_error::{VortexExpect, VortexResult}; use vortex_flatbuffers::WriteFlatBuffer; +use vortex_mask::Mask; use vortex_scalar::{Scalar, ScalarValue}; use crate::encoding::ids; use crate::stats::{Stat, StatisticsVTable, StatsSet}; use crate::validate::ValidateVTable; -use crate::validity::{LogicalValidity, ValidityVTable}; +use crate::validity::ValidityVTable; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{impl_encoding, ArrayDType, ArrayLen, EmptyMetadata}; @@ -60,10 +61,10 @@ impl ValidityVTable for ConstantEncoding { Ok(!array.scalar().is_null()) } - fn logical_validity(&self, array: &ConstantArray) -> VortexResult { + fn logical_validity(&self, array: &ConstantArray) -> VortexResult { Ok(match array.scalar().is_null() { - true => LogicalValidity::AllInvalid(array.len()), - false => LogicalValidity::AllValid(array.len()), + true => Mask::AllFalse(array.len()), + false => Mask::AllTrue(array.len()), }) } } diff --git a/vortex-array/src/array/extension/mod.rs b/vortex-array/src/array/extension/mod.rs index 84b61f0640..302c41cb68 100644 --- a/vortex-array/src/array/extension/mod.rs +++ b/vortex-array/src/array/extension/mod.rs @@ -4,11 +4,12 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use vortex_dtype::{DType, ExtDType, ExtID}; use vortex_error::{VortexExpect as _, VortexResult}; +use vortex_mask::Mask; use crate::encoding::ids; use crate::stats::{ArrayStatistics as _, Stat, StatisticsVTable, StatsSet}; use crate::validate::ValidateVTable; -use crate::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use crate::validity::{ArrayValidity, ValidityVTable}; use crate::variants::{ExtensionArrayTrait, VariantsVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ @@ -78,7 +79,7 @@ impl ValidityVTable for ExtensionEncoding { array.storage().is_valid(index) } - fn logical_validity(&self, array: &ExtensionArray) -> VortexResult { + fn logical_validity(&self, array: &ExtensionArray) -> VortexResult { array.storage().logical_validity() } } diff --git a/vortex-array/src/array/list/mod.rs b/vortex-array/src/array/list/mod.rs index 7e4aaf2f56..03d42b01fb 100644 --- a/vortex-array/src/array/list/mod.rs +++ b/vortex-array/src/array/list/mod.rs @@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize}; use vortex_dtype::Nullability; use vortex_dtype::{match_each_native_ptype, DType, PType}; use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect, VortexResult}; +use vortex_mask::Mask; #[cfg(feature = "test-harness")] use vortex_scalar::Scalar; @@ -22,7 +23,7 @@ use crate::compute::{scalar_at, slice}; use crate::encoding::ids; use crate::stats::{StatisticsVTable, StatsSet}; use crate::validate::ValidateVTable; -use crate::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use crate::validity::{Validity, ValidityMetadata, ValidityVTable}; use crate::variants::{ListArrayTrait, PrimitiveArrayTrait, VariantsVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ @@ -192,7 +193,7 @@ impl ValidityVTable for ListEncoding { array.is_valid(index) } - fn logical_validity(&self, array: &ListArray) -> VortexResult { + fn logical_validity(&self, array: &ListArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index 28ce5899bd..123b8839f4 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -63,10 +63,11 @@ impl TakeFn for NullEncoding { mod test { use vortex_buffer::buffer; use vortex_dtype::DType; + use vortex_mask::Mask; use crate::array::null::NullArray; use crate::compute::{scalar_at, slice, take}; - use crate::validity::{ArrayValidity, LogicalValidity}; + use crate::validity::ArrayValidity; use crate::{ArrayLen, IntoArrayData}; #[test] @@ -78,7 +79,7 @@ mod test { assert_eq!(sliced.len(), 4); assert!(matches!( sliced.logical_validity().unwrap(), - LogicalValidity::AllInvalid(4) + Mask::AllFalse(4) )); } @@ -92,7 +93,7 @@ mod test { assert_eq!(taken.len(), 5); assert!(matches!( taken.logical_validity().unwrap(), - LogicalValidity::AllInvalid(5) + Mask::AllFalse(5) )); } diff --git a/vortex-array/src/array/null/mod.rs b/vortex-array/src/array/null/mod.rs index f0d397077e..7acc17053a 100644 --- a/vortex-array/src/array/null/mod.rs +++ b/vortex-array/src/array/null/mod.rs @@ -4,12 +4,13 @@ use serde::{Deserialize, Serialize}; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_error::{VortexExpect as _, VortexResult}; +use vortex_mask::Mask; use crate::encoding::ids; use crate::nbytes::ArrayNBytes; use crate::stats::{Stat, StatisticsVTable, StatsSet}; use crate::validate::ValidateVTable; -use crate::validity::{LogicalValidity, Validity, ValidityVTable}; +use crate::validity::{Validity, ValidityVTable}; use crate::variants::{NullArrayTrait, VariantsVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{impl_encoding, ArrayLen, Canonical, EmptyMetadata, IntoCanonical}; @@ -43,8 +44,8 @@ impl ValidityVTable for NullEncoding { Ok(false) } - fn logical_validity(&self, array: &NullArray) -> VortexResult { - Ok(LogicalValidity::AllInvalid(array.len())) + fn logical_validity(&self, array: &NullArray) -> VortexResult { + Ok(Mask::AllFalse(array.len())) } } diff --git a/vortex-array/src/array/primitive/compute/cast.rs b/vortex-array/src/array/primitive/compute/cast.rs index 48ad063392..c6b1b3469f 100644 --- a/vortex-array/src/array/primitive/compute/cast.rs +++ b/vortex-array/src/array/primitive/compute/cast.rs @@ -23,7 +23,7 @@ impl CastFn for PrimitiveEncoding { // from non-nullable to nullable array.validity().into_nullable() } else if new_nullability == Nullability::NonNullable - && array.validity().to_logical(array.len())?.all_valid() + && array.validity().to_logical(array.len())?.all_true() { // from nullable but all valid, to non-nullable Validity::NonNullable diff --git a/vortex-array/src/array/primitive/compute/fill.rs b/vortex-array/src/array/primitive/compute/fill.rs index 5f8aca7e70..e2769ed612 100644 --- a/vortex-array/src/array/primitive/compute/fill.rs +++ b/vortex-array/src/array/primitive/compute/fill.rs @@ -1,6 +1,8 @@ +use arrow_buffer::BooleanBuffer; use vortex_buffer::Buffer; use vortex_dtype::{match_each_native_ptype, Nullability}; use vortex_error::{vortex_err, VortexResult}; +use vortex_mask::AllOr; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; @@ -16,44 +18,39 @@ impl FillForwardFn for PrimitiveEncoding { return Ok(array.to_array()); } - let validity = array.logical_validity()?; - if validity.all_valid() { - return Ok(PrimitiveArray::from_byte_buffer( + match array.logical_validity()?.boolean_buffer() { + AllOr::All => Ok(PrimitiveArray::from_byte_buffer( array.byte_buffer().clone(), array.ptype(), Validity::AllValid, ) - .into_array()); + .into_array()), + AllOr::None => { + match_each_native_ptype!(array.ptype(), |$T| { + let fill_value = Scalar::from($T::default()).cast(array.dtype())?; + return Ok(ConstantArray::new(fill_value, array.len()).into_array()) + }) + } + AllOr::Some(validity) => { + // TODO(ngates): when we take PrimitiveArray by value, we should mutate in-place + match_each_native_ptype!(array.ptype(), |$T| { + let as_slice = array.as_slice::<$T>(); + let mut last_value = $T::default(); + let filled = Buffer::from_iter( + as_slice + .iter() + .zip(validity.into_iter()) + .map(|(v, valid)| { + if valid { + last_value = *v; + } + last_value + }) + ); + Ok(PrimitiveArray::new(filled, Validity::AllValid).into_array()) + }) + } } - - if validity.all_invalid() { - match_each_native_ptype!(array.ptype(), |$T| { - let fill_value = Scalar::from($T::default()).cast(array.dtype())?; - return Ok(ConstantArray::new(fill_value, array.len()).into_array()) - }) - } - - let nulls = validity - .to_null_buffer()? - .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?; - - // TODO(ngates): when we take PrimitiveArray by value, we should mutate in-place - match_each_native_ptype!(array.ptype(), |$T| { - let as_slice = array.as_slice::<$T>(); - let mut last_value = $T::default(); - let filled = Buffer::from_iter( - as_slice - .iter() - .zip(nulls.into_iter()) - .map(|(v, valid)| { - if valid { - last_value = *v; - } - last_value - }) - ); - Ok(PrimitiveArray::new(filled, Validity::AllValid).into_array()) - }) } } @@ -73,7 +70,7 @@ mod test { PrimitiveArray::from_option_iter([None, Some(8u8), None, Some(10), None]).into_array(); let p = fill_forward(&arr).unwrap().into_primitive().unwrap(); assert_eq!(p.as_slice::(), vec![0, 8, 8, 10, 10]); - assert!(p.logical_validity().unwrap().all_valid()); + assert!(p.logical_validity().unwrap().all_true()); } #[test] @@ -83,7 +80,7 @@ mod test { let p = fill_forward(&arr).unwrap().into_primitive().unwrap(); assert_eq!(p.as_slice::(), vec![0, 0, 0, 0, 0]); - assert!(p.logical_validity().unwrap().all_valid()); + assert!(p.logical_validity().unwrap().all_true()); } #[test] @@ -95,6 +92,6 @@ mod test { .into_array(); let p = fill_forward(&arr).unwrap().into_primitive().unwrap(); assert_eq!(p.as_slice::(), vec![8u8, 10, 12, 14, 16]); - assert!(p.logical_validity().unwrap().all_valid()); + assert!(p.logical_validity().unwrap().all_true()); } } diff --git a/vortex-array/src/array/primitive/compute/fill_null.rs b/vortex-array/src/array/primitive/compute/fill_null.rs index fb98ee7ef7..b1d1aff6ad 100644 --- a/vortex-array/src/array/primitive/compute/fill_null.rs +++ b/vortex-array/src/array/primitive/compute/fill_null.rs @@ -65,7 +65,7 @@ mod test { .into_primitive() .unwrap(); assert_eq!(p.as_slice::(), vec![42, 8, 42, 10, 42]); - assert!(p.logical_validity().unwrap().all_valid()); + assert!(p.logical_validity().unwrap().all_true()); } #[test] @@ -78,7 +78,7 @@ mod test { .into_primitive() .unwrap(); assert_eq!(p.as_slice::(), vec![255, 255, 255, 255, 255]); - assert!(p.logical_validity().unwrap().all_valid()); + assert!(p.logical_validity().unwrap().all_true()); } #[test] @@ -93,7 +93,7 @@ mod test { .into_primitive() .unwrap(); assert_eq!(p.as_slice::(), vec![8, 10, 12, 14, 16]); - assert!(p.logical_validity().unwrap().all_valid()); + assert!(p.logical_validity().unwrap().all_true()); } #[test] @@ -104,6 +104,6 @@ mod test { .into_primitive() .unwrap(); assert_eq!(p.as_slice::(), vec![8u8, 10, 12, 14, 16]); - assert!(p.logical_validity().unwrap().all_valid()); + assert!(p.logical_validity().unwrap().all_true()); } } diff --git a/vortex-array/src/array/primitive/compute/filter.rs b/vortex-array/src/array/primitive/compute/filter.rs index 7392e02201..a92d50e10f 100644 --- a/vortex-array/src/array/primitive/compute/filter.rs +++ b/vortex-array/src/array/primitive/compute/filter.rs @@ -1,24 +1,41 @@ +use std::sync::Arc; + use vortex_buffer::{Buffer, BufferMut}; use vortex_dtype::match_each_native_ptype; -use vortex_error::VortexResult; -use vortex_mask::{Mask, MaskIter}; +use vortex_error::{VortexExpect, VortexResult}; +use vortex_mask::{Mask, MaskIter, MaskValues}; use crate::array::primitive::PrimitiveArray; use crate::array::PrimitiveEncoding; use crate::compute::FilterFn; use crate::variants::PrimitiveArrayTrait; -use crate::{ArrayData, IntoArrayData}; +use crate::{ArrayDType, ArrayData, Canonical, IntoArrayData}; + +// This is modeled after the constant with the equivalent name in arrow-rs. +const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; impl FilterFn for PrimitiveEncoding { fn filter(&self, array: &PrimitiveArray, mask: &Mask) -> VortexResult { let validity = array.validity().filter(mask)?; - match_each_native_ptype!(array.ptype(), |$T| { - let values = match mask.iter() { - MaskIter::Indices(indices) => filter_primitive_indices(array.as_slice::<$T>(), indices.iter().copied()), - MaskIter::Slices(slices) => filter_primitive_slices(array.as_slice::<$T>(), mask.true_count(), slices.iter().copied()), - }; - Ok(PrimitiveArray::new(values, validity).into_array()) - }) + + let mask_values = mask + .values() + .vortex_expect("AllTrue and AllFalse are handled by filter fn"); + + match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { + MaskIter::Indices(indices) => { + match_each_native_ptype!(array.ptype(), |$T| { + let values = filter_primitive_indices(array.as_slice::<$T>(), indices.iter().copied()); + Ok(PrimitiveArray::new(values, validity).into_array()) + }) + } + MaskIter::Slices(slices) => { + match_each_native_ptype!(array.ptype(), |$T| { + let values = filter_primitive_slices(array.as_slice::<$T>(), mask.true_count(), slices.iter().copied()); + Ok(PrimitiveArray::new(values, validity).into_array()) + }) + } + } } } diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index 8121003ee5..e61a29aef6 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -9,12 +9,13 @@ use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer}; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult}; use vortex_flatbuffers::dtype::Primitive; +use vortex_mask::Mask; use crate::encoding::ids; use crate::iter::Accessor; use crate::stats::StatsSet; use crate::validate::ValidateVTable; -use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use crate::validity::{ArrayValidity, Validity, ValidityMetadata, ValidityVTable}; use crate::variants::{PrimitiveArrayTrait, VariantsVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ @@ -340,7 +341,7 @@ impl ValidityVTable for PrimitiveEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &PrimitiveArray) -> VortexResult { + fn logical_validity(&self, array: &PrimitiveArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index fd63a57813..2ed6c5a47b 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -8,13 +8,14 @@ use num_traits::PrimInt; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability}; use vortex_error::{vortex_panic, VortexError, VortexResult}; +use vortex_mask::Mask; use vortex_scalar::ScalarValue; use crate::array::primitive::PrimitiveArray; use crate::array::PrimitiveEncoding; use crate::nbytes::ArrayNBytes; use crate::stats::{Stat, StatisticsVTable, StatsSet}; -use crate::validity::{ArrayValidity, LogicalValidity}; +use crate::validity::ArrayValidity; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, IntoArrayVariant}; @@ -39,12 +40,12 @@ impl StatisticsVTable for PrimitiveEncoding { let mut stats = match_each_native_ptype!(array.ptype(), |$P| { match array.logical_validity()? { - LogicalValidity::AllValid(_) => self.compute_statistics(array.as_slice::<$P>(), stat), - LogicalValidity::AllInvalid(v) => Ok(StatsSet::nulls(v, array.dtype())), - LogicalValidity::Mask(m) => self.compute_statistics( + Mask::AllTrue(_) => self.compute_statistics(array.as_slice::<$P>(), stat), + Mask::AllFalse(len) => Ok(StatsSet::nulls(len, array.dtype())), + Mask::Values(v) => self.compute_statistics( &NullableValues( array.as_slice::<$P>(), - m.boolean_buffer(), + v.boolean_buffer(), ), stat ), diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index f740707130..37f9ab4621 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; + use itertools::Itertools; use vortex_error::VortexResult; -use vortex_mask::Mask; +use vortex_mask::{Mask, MaskValues}; use vortex_scalar::Scalar; use crate::array::struct_::StructArray; diff --git a/vortex-array/src/array/struct_/mod.rs b/vortex-array/src/array/struct_/mod.rs index 5c906379a1..6da51c711c 100644 --- a/vortex-array/src/array/struct_/mod.rs +++ b/vortex-array/src/array/struct_/mod.rs @@ -7,11 +7,12 @@ use vortex_dtype::{DType, Field, FieldName, FieldNames, StructDType}; use vortex_error::{ vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult, }; +use vortex_mask::Mask; use crate::encoding::ids; use crate::stats::{ArrayStatistics, Stat, StatisticsVTable, StatsSet}; use crate::validate::ValidateVTable; -use crate::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use crate::validity::{Validity, ValidityMetadata, ValidityVTable}; use crate::variants::{StructArrayTrait, VariantsVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ @@ -200,7 +201,7 @@ impl ValidityVTable for StructEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &StructArray) -> VortexResult { + fn logical_validity(&self, array: &StructArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/vortex-array/src/array/varbin/array.rs b/vortex-array/src/array/varbin/array.rs index 96a35d434d..798825f5c7 100644 --- a/vortex-array/src/array/varbin/array.rs +++ b/vortex-array/src/array/varbin/array.rs @@ -1,8 +1,9 @@ use vortex_error::VortexResult; +use vortex_mask::Mask; use crate::array::varbin::VarBinArray; use crate::array::VarBinEncoding; -use crate::validity::{LogicalValidity, ValidityVTable}; +use crate::validity::ValidityVTable; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::ArrayLen; @@ -11,7 +12,7 @@ impl ValidityVTable for VarBinEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &VarBinArray) -> VortexResult { + fn logical_validity(&self, array: &VarBinArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/vortex-array/src/array/varbin/arrow.rs b/vortex-array/src/array/varbin/arrow.rs index 9852b5359a..b61678aebd 100644 --- a/vortex-array/src/array/varbin/arrow.rs +++ b/vortex-array/src/array/varbin/arrow.rs @@ -26,11 +26,8 @@ pub(crate) fn varbin_to_arrow(varbin_array: &VarBinArray) -> VortexResult for VarBinEncoding { } fn filter_select_var_bin(arr: &VarBinArray, mask: &Mask) -> VortexResult { - let selection_count = mask.true_count(); - if selection_count * 2 > mask.len() { - filter_select_var_bin_by_slice(arr, mask, selection_count) - } else { - filter_select_var_bin_by_index(arr, mask, selection_count) + match mask + .values() + .vortex_expect("AllTrue and AllFalse are handled by filter fn") + .threshold_iter(0.5) + { + MaskIter::Indices(indices) => { + filter_select_var_bin_by_index(arr, indices, mask.true_count()) + } + MaskIter::Slices(slices) => filter_select_var_bin_by_slice(arr, slices, mask.true_count()), } } fn filter_select_var_bin_by_slice( values: &VarBinArray, - mask: &Mask, + mask_slices: &[(usize, usize)], selection_count: usize, ) -> VortexResult { let offsets = values.offsets().into_primitive()?; @@ -38,7 +44,7 @@ fn filter_select_var_bin_by_slice( values.dtype().clone(), offsets.as_slice::<$O>(), values.bytes().as_slice(), - mask, + mask_slices, values.validity(), selection_count ) @@ -50,7 +56,7 @@ fn filter_select_var_bin_by_slice_primitive_offset( dtype: DType, offsets: &[O], data: &[u8], - mask: &Mask, + mask_slices: &[(usize, usize)], validity: Validity, selection_count: usize, ) -> VortexResult @@ -59,12 +65,12 @@ where usize: AsPrimitive, { let logical_validity = validity.to_logical(offsets.len() - 1)?; - if let Some(val) = logical_validity.to_null_buffer()? { + if let AllOr::Some(validity) = logical_validity.boolean_buffer() { let mut builder = VarBinBuilder::::with_capacity(selection_count); - for (start, end) in mask.slices().iter().copied() { - let null_sl = val.slice(start, end - start); - if null_sl.null_count() == 0 { + for (start, end) in mask_slices.iter().copied() { + let null_sl = validity.slice(start, end - start); + if null_sl.count_set_bits() == 0 { update_non_nullable_slice(data, offsets, &mut builder, start, end) } else { for (idx, valid) in null_sl.iter().enumerate() { @@ -94,7 +100,7 @@ where let mut builder = VarBinBuilder::::with_capacity(selection_count); - mask.slices().iter().for_each(|(start, end)| { + mask_slices.iter().for_each(|(start, end)| { update_non_nullable_slice(data, offsets, &mut builder, *start, *end) }); @@ -129,7 +135,7 @@ fn update_non_nullable_slice( fn filter_select_var_bin_by_index( values: &VarBinArray, - mask: &Mask, + mask_indices: &[usize], selection_count: usize, ) -> VortexResult { let offsets = values.offsets().into_primitive()?; @@ -138,7 +144,7 @@ fn filter_select_var_bin_by_index( values.dtype().clone(), offsets.as_slice::<$O>(), values.bytes().as_slice(), - mask, + mask_indices, values.validity(), selection_count ) @@ -150,12 +156,13 @@ fn filter_select_var_bin_by_index_primitive_offset( dtype: DType, offsets: &[O], data: &[u8], - mask: &Mask, + mask_indices: &[usize], + // TODO(ngates): pass LogicalValidity instead validity: Validity, selection_count: usize, ) -> VortexResult { let mut builder = VarBinBuilder::::with_capacity(selection_count); - for idx in mask.indices().iter().copied() { + for idx in mask_indices.iter().copied() { if validity.is_valid(idx)? { let (start, end) = ( offsets[idx].to_usize().ok_or_else(|| { @@ -205,9 +212,7 @@ mod test { ], DType::Utf8(NonNullable), ); - let filter = Mask::from_iter([true, false, true]); - - let buf = filter_select_var_bin_by_index(&arr, &filter, 2) + let buf = filter_select_var_bin_by_index(&arr, &[0, 2], 2) .unwrap() .to_array(); @@ -228,9 +233,8 @@ mod test { ], DType::Utf8(NonNullable), ); - let filter = Mask::from_iter([true, false, true, false, true]); - let buf = filter_select_var_bin_by_slice(&arr, &filter, 3) + let buf = filter_select_var_bin_by_slice(&arr, &[(0, 1), (2, 3), (4, 5)], 3) .unwrap() .to_array(); @@ -258,9 +262,8 @@ mod test { let validity = Validity::Array(BoolArray::from_iter([true, false, true, true, true, true]).to_array()); let arr = VarBinArray::try_new(offsets, bytes, DType::Utf8(Nullable), validity).unwrap(); - let filter = Mask::from_iter([true, true, true, false, true, true]); - let buf = filter_select_var_bin_by_slice(&arr, &filter, 5) + let buf = filter_select_var_bin_by_slice(&arr, &[(0, 3), (4, 6)], 5) .unwrap() .to_array(); diff --git a/vortex-array/src/array/varbin/compute/take.rs b/vortex-array/src/array/varbin/compute/take.rs index 3a8b0ab39a..bb2f369b2e 100644 --- a/vortex-array/src/array/varbin/compute/take.rs +++ b/vortex-array/src/array/varbin/compute/take.rs @@ -38,7 +38,7 @@ fn take( validity: Validity, ) -> VortexResult { let logical_validity = validity.to_logical(offsets.len() - 1)?; - if let Some(v) = logical_validity.to_null_buffer()? { + if let Some(v) = logical_validity.to_null_buffer() { return Ok(take_nullable(dtype, offsets, data, indices, v)); } diff --git a/vortex-array/src/array/varbin/stats.rs b/vortex-array/src/array/varbin/stats.rs index a6767427e4..0061920d31 100644 --- a/vortex-array/src/array/varbin/stats.rs +++ b/vortex-array/src/array/varbin/stats.rs @@ -36,7 +36,7 @@ pub fn compute_varbin_statistics>( Ok(match stat { Stat::NullCount => { - let null_count = array.logical_validity()?.null_count(); + let null_count = array.logical_validity()?.false_count(); if null_count == array.len() { return Ok(StatsSet::nulls(array.len(), array.dtype())); } diff --git a/vortex-array/src/array/varbinview/mod.rs b/vortex-array/src/array/varbinview/mod.rs index 1ac65cead9..68c36ca852 100644 --- a/vortex-array/src/array/varbinview/mod.rs +++ b/vortex-array/src/array/varbinview/mod.rs @@ -12,12 +12,13 @@ use static_assertions::{assert_eq_align, assert_eq_size}; use vortex_buffer::{Alignment, Buffer, ByteBuffer}; use vortex_dtype::DType; use vortex_error::{vortex_bail, vortex_panic, VortexExpect, VortexResult, VortexUnwrap}; +use vortex_mask::Mask; use crate::arrow::FromArrowArray; use crate::encoding::ids; use crate::stats::StatsSet; use crate::validate::ValidateVTable; -use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata, ValidityVTable}; +use crate::validity::{ArrayValidity, Validity, ValidityMetadata, ValidityVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ impl_encoding, ArrayDType, ArrayData, ArrayLen, Canonical, DeserializeMetadata, IntoCanonical, @@ -446,8 +447,7 @@ pub(crate) fn varbinview_as_arrow(var_bin_view: &VarBinViewArray) -> ArrayRef { let nulls = var_bin_view .logical_validity() .vortex_expect("VarBinViewArray: failed to get logical validity") - .to_null_buffer() - .vortex_expect("VarBinViewArray: validity child must be bool"); + .to_null_buffer(); let data = (0..var_bin_view.nbuffers()) .map(|i| var_bin_view.buffer(i)) @@ -483,7 +483,7 @@ impl ValidityVTable for VarBinViewEncoding { array.validity().is_valid(index) } - fn logical_validity(&self, array: &VarBinViewArray) -> VortexResult { + fn logical_validity(&self, array: &VarBinViewArray) -> VortexResult { array.validity().to_logical(array.len()) } } diff --git a/vortex-array/src/canonical.rs b/vortex-array/src/canonical.rs index cdb41c2dca..a43526923b 100644 --- a/vortex-array/src/canonical.rs +++ b/vortex-array/src/canonical.rs @@ -170,7 +170,7 @@ fn bool_to_arrow(bool_array: BoolArray, data_type: &DataType) -> VortexResult VortexRes }) .collect::>>()?; - let nulls = struct_array.logical_validity()?.to_null_buffer()?; + let nulls = struct_array.logical_validity()?.to_null_buffer(); if field_arrays.is_empty() { Ok(Arc::new(ArrowStructArray::new_empty_fields( @@ -292,7 +292,7 @@ fn list_to_arrow(list: ListArray, data_type: &DataType) -> VortexResult Arc::new(arrow_array::ListArray::try_new( @@ -319,7 +319,7 @@ fn temporal_to_arrow(temporal_array: TemporalArray) -> VortexResult { &DType::Primitive(<$prim as NativePType>::PTYPE, $values.dtype().nullability()), )? .into_primitive()?; - let nulls = temporal_values.logical_validity()?.to_null_buffer()?; + let nulls = temporal_values.logical_validity()?.to_null_buffer(); let scalars = temporal_values.into_buffer().into_arrow_scalar_buffer(); (scalars, nulls) diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 9287b523a1..91802a09b3 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -248,8 +248,7 @@ mod tests { .validity() .to_logical(indices_bits.len()) .unwrap() - .to_null_buffer() - .unwrap(); + .to_null_buffer(); let is_valid = |idx: usize| match null_buffer.as_ref() { None => true, Some(buffer) => buffer.is_valid(idx), diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 42600f3093..acdabc1b1c 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -6,7 +6,7 @@ use arrow_array::BooleanArray; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect, VortexResult}; -use vortex_mask::Mask; +use vortex_mask::{Mask, MaskValues}; use crate::array::{BoolArray, ConstantArray}; use crate::arrow::FromArrowArray; @@ -17,6 +17,9 @@ use crate::{ArrayDType, ArrayData, Canonical, IntoArrayData, IntoArrayVariant, I pub trait FilterFn { /// Filter an array by the provided predicate. + /// + /// Note that the entry-point filter functions handles `Mask::AllTrue` and `Mask::AllFalse`, + /// leaving only `Mask::Values` to be handled by this function. fn filter(&self, array: &Array, mask: &Mask) -> VortexResult; } @@ -81,8 +84,18 @@ pub fn filter(array: &ArrayData, mask: &Mask) -> VortexResult { } fn filter_impl(array: &ArrayData, mask: &Mask) -> VortexResult { + // Since we handle the AllTrue and AllFalse cases in the entry-point filter function, + // implementations can use `AllOr::expect_some` to unwrap the mixed values variant. + let values = match &mask { + Mask::AllTrue(_) => return Ok(array.clone()), + Mask::AllFalse(_) => return Ok(Canonical::empty(array.dtype()).into_array()), + Mask::Values(values) => values, + }; + if let Some(filter_fn) = array.encoding().filter_fn() { - return filter_fn.filter(array, mask); + let result = filter_fn.filter(array, mask)?; + debug_assert_eq!(result.len(), mask.true_count()); + return Ok(result); } // We can use scalar_at if the mask has length 1. @@ -98,7 +111,7 @@ fn filter_impl(array: &ArrayData, mask: &Mask) -> VortexResult { ); let array_ref = array.clone().into_arrow()?; - let mask_array = BooleanArray::new(mask.boolean_buffer().clone(), None); + let mask_array = BooleanArray::new(values.boolean_buffer().clone(), None); let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?; Ok(ArrayData::from_arrow(filtered, array.dtype().is_nullable())) @@ -131,12 +144,6 @@ impl TryFrom for Mask { } } -impl IntoArrayData for Mask { - fn into_array(self) -> ArrayData { - BoolArray::new(self.boolean_buffer().clone(), Nullability::NonNullable).into_array() - } -} - #[cfg(test)] mod test { use super::*; diff --git a/vortex-array/src/data/mod.rs b/vortex-array/src/data/mod.rs index 3f4ca99b0d..4f25045b57 100644 --- a/vortex-array/src/data/mod.rs +++ b/vortex-array/src/data/mod.rs @@ -8,6 +8,7 @@ use vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_error::{vortex_err, VortexError, VortexExpect, VortexResult}; use vortex_flatbuffers::FlatBuffer; +use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::array::{ @@ -19,7 +20,7 @@ use crate::encoding::{Encoding, EncodingId, EncodingRef, EncodingVTable}; use crate::iter::{ArrayIterator, ArrayIteratorAdapter}; use crate::stats::{ArrayStatistics, Stat, Statistics, StatsSet}; use crate::stream::{ArrayStream, ArrayStreamAdapter}; -use crate::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; +use crate::validity::{ArrayValidity, ValidityVTable}; use crate::{ ArrayChildrenIterator, ArrayDType, ArrayLen, ChildrenCollector, ContextRef, NamedChildrenCollector, @@ -410,7 +411,7 @@ impl> ArrayValidity for A { } /// Return the logical validity of the array if nullable, and None if non-nullable. - fn logical_validity(&self) -> VortexResult { + fn logical_validity(&self) -> VortexResult { ValidityVTable::::logical_validity(self.as_ref().encoding(), self.as_ref()) } } diff --git a/vortex-array/src/encoding/opaque.rs b/vortex-array/src/encoding/opaque.rs index f8dd3dd59f..536ca9b8c1 100644 --- a/vortex-array/src/encoding/opaque.rs +++ b/vortex-array/src/encoding/opaque.rs @@ -3,12 +3,13 @@ use std::fmt::{Debug, Display, Formatter}; use arrow_array::ArrayRef; use vortex_error::{vortex_bail, vortex_panic, VortexResult}; +use vortex_mask::Mask; use crate::compute::ComputeVTable; use crate::encoding::{EncodingId, EncodingVTable}; use crate::stats::StatisticsVTable; use crate::validate::ValidateVTable; -use crate::validity::{LogicalValidity, ValidityVTable}; +use crate::validity::ValidityVTable; use crate::variants::VariantsVTable; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ArrayData, Canonical, EmptyMetadata, IntoCanonicalVTable, MetadataVTable}; @@ -78,7 +79,7 @@ impl ValidityVTable for OpaqueEncoding { ) } - fn logical_validity(&self, _array: &ArrayData) -> VortexResult { + fn logical_validity(&self, _array: &ArrayData) -> VortexResult { vortex_panic!( "OpaqueEncoding: logical_validity cannot be called for opaque array ({})", self.0 diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index cad56a26e0..b069844156 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -1,6 +1,8 @@ use std::cmp::Ordering; use std::fmt::Debug; +use std::sync::Arc; +use arrow_buffer::BooleanBuffer; use itertools::Itertools as _; use num_traits::{AsPrimitive, NumCast, ToPrimitive}; use serde::{Deserialize, Serialize}; @@ -8,7 +10,7 @@ use vortex_buffer::BufferMut; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{match_each_integer_ptype, DType, NativePType, PType}; use vortex_error::{vortex_bail, VortexExpect, VortexResult}; -use vortex_mask::Mask; +use vortex_mask::{AllOr, Mask, MaskValues}; use vortex_scalar::Scalar; use crate::aliases::hash_map::HashMap; @@ -210,18 +212,20 @@ impl Patches { /// Filter the patches by a mask, resulting in new patches for the filtered array. pub fn filter(&self, mask: &Mask) -> VortexResult> { - if mask.true_count() == 0 { - return Ok(None); + match mask.indices() { + AllOr::All => Ok(Some(self.clone())), + AllOr::None => Ok(None), + AllOr::Some(mask_indices) => { + let flat_indices = self.indices().clone().into_primitive()?; + match_each_integer_ptype!(flat_indices.ptype(), |$I| { + filter_patches_with_mask( + flat_indices.as_slice::<$I>(), + self.values(), + mask_indices, + ) + }) + } } - - let flat_indices = self.indices().clone().into_primitive()?; - match_each_integer_ptype!(flat_indices.ptype(), |$I| { - filter_patches_with_mask( - flat_indices.as_slice::<$I>(), - self.values(), - mask - ) - }) } /// Slice the patches by a range of the patched array. @@ -387,10 +391,11 @@ impl Patches { fn filter_patches_with_mask( patch_indices: &[T], patch_values: &ArrayData, - mask: &Mask, + mask_indices: &[usize], ) -> VortexResult> { - let mut new_patch_indices = BufferMut::::with_capacity(mask.true_count()); - let mut new_mask_indices = Vec::with_capacity(mask.true_count()); + let true_count = mask_indices.len(); + let mut new_patch_indices = BufferMut::::with_capacity(true_count); + let mut new_mask_indices = Vec::with_capacity(true_count); // Attempt to move the window by `STRIDE` elements on each iteration. This assumes that // the patches are relatively sparse compared to the overall mask, and so many indices in the @@ -400,9 +405,7 @@ fn filter_patches_with_mask( let mut mask_idx = 0usize; let mut true_idx = 0usize; - let mask_indices = mask.indices(); - - while mask_idx < patch_indices.len() && true_idx < mask.true_count() { + while mask_idx < patch_indices.len() && true_idx < true_count { // NOTE: we are searching for overlaps between sorted, unaligned indices in `patch_indices` // and `mask_indices`. We assume that Patches are sparse relative to the global space of // the mask (which covers both patch and non-patch values of the parent array), and so to @@ -464,7 +467,7 @@ fn filter_patches_with_mask( )?; Ok(Some(Patches::new( - mask.true_count(), + true_count, new_patch_indices, new_patch_values, ))) diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index aa1edc2adf..14f713ca38 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -11,7 +11,8 @@ use vortex_dtype::{DType, Nullability}; use vortex_error::{ vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult, }; -use vortex_mask::Mask; +use vortex_mask::{Mask, MaskValues}; +use vortex_scalar::Scalar; use crate::array::{BoolArray, ConstantArray}; use crate::compute::{filter, scalar_at, slice, take}; @@ -24,15 +25,15 @@ use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; pub trait ValidityVTable { /// Returns whether the `index` item is valid. fn is_valid(&self, array: &Array, index: usize) -> VortexResult { - Ok(self.logical_validity(array)?.is_valid(index)) + Ok(self.logical_validity(array)?.value(index)) } /// Returns the number of invalid elements in the array. fn null_count(&self, array: &Array) -> VortexResult { - Ok(self.logical_validity(array)?.null_count()) + Ok(self.logical_validity(array)?.false_count()) } - fn logical_validity(&self, array: &Array) -> VortexResult; + fn logical_validity(&self, array: &Array) -> VortexResult; } impl ValidityVTable for E @@ -55,7 +56,7 @@ where ValidityVTable::null_count(encoding, array_ref) } - fn logical_validity(&self, array: &ArrayData) -> VortexResult { + fn logical_validity(&self, array: &ArrayData) -> VortexResult { let (array_ref, encoding) = array .try_downcast_ref::() .vortex_expect("Failed to downcast encoding"); @@ -66,7 +67,7 @@ where pub trait ArrayValidity { fn is_valid(&self, index: usize) -> VortexResult; fn null_count(&self) -> VortexResult; - fn logical_validity(&self) -> VortexResult; + fn logical_validity(&self) -> VortexResult; } #[derive( @@ -265,21 +266,11 @@ impl Validity { } } - pub fn to_logical(&self, length: usize) -> VortexResult { + pub fn to_logical(&self, length: usize) -> VortexResult { Ok(match self { - Self::NonNullable => LogicalValidity::AllValid(length), - Self::AllValid => LogicalValidity::AllValid(length), - Self::AllInvalid => LogicalValidity::AllInvalid(length), - Self::Array(a) => { - let mask = Mask::try_from(a.clone())?; - if mask.true_count() == a.len() { - LogicalValidity::AllValid(length) - } else if mask.true_count() == 0 { - LogicalValidity::AllInvalid(length) - } else { - LogicalValidity::Mask(mask) - } - } + Self::NonNullable | Self::AllValid => Mask::AllTrue(length), + Self::AllInvalid => Mask::AllFalse(length), + Self::Array(a) => Mask::try_from(a.clone())?, }) } @@ -421,16 +412,16 @@ impl From for Validity { } } -impl FromIterator for Validity { - fn from_iter>(iter: T) -> Self { - let validities: Vec = iter.into_iter().collect(); +impl FromIterator for Validity { + fn from_iter>(iter: T) -> Self { + let validities: Vec = iter.into_iter().collect(); // If they're all valid, then return a single validity. - if validities.iter().all(|v| v.all_valid()) { + if validities.iter().all(|v| v.all_true()) { return Self::AllValid; } // If they're all invalid, then return a single invalidity. - if validities.iter().all(|v| v.all_invalid()) { + if validities.iter().all(|v| v.all_false()) { return Self::AllInvalid; } @@ -438,10 +429,10 @@ impl FromIterator for Validity { let mut buffer = BooleanBufferBuilder::new(validities.iter().map(|v| v.len()).sum()); for validity in validities { match validity { - LogicalValidity::AllValid(count) => buffer.append_n(count, true), - LogicalValidity::AllInvalid(count) => buffer.append_n(count, false), - LogicalValidity::Mask(mask) => { - buffer.append_buffer(mask.boolean_buffer()); + Mask::AllTrue(count) => buffer.append_n(count, true), + Mask::AllFalse(count) => buffer.append_n(count, false), + Mask::Values(values) => { + buffer.append_buffer(values.boolean_buffer()); } }; } @@ -465,105 +456,36 @@ impl From for Validity { } } -/// Logical validity actually represents "canonical validity". -#[derive(Clone, Debug)] -pub enum LogicalValidity { - AllValid(usize), - AllInvalid(usize), - Mask(Mask), -} - -impl LogicalValidity { - pub fn to_null_buffer(&self) -> VortexResult> { - match self { - Self::AllValid(_) => Ok(None), - Self::AllInvalid(l) => Ok(Some(NullBuffer::new_null(*l))), - Self::Mask(a) => Ok(Some(NullBuffer::new(a.boolean_buffer().clone()))), - } - } - - pub fn to_boolean_buffer(&self) -> BooleanBuffer { - match self { - Self::AllValid(l) => BooleanBuffer::new_set(*l), - Self::AllInvalid(l) => BooleanBuffer::new_unset(*l), - Self::Mask(a) => a.boolean_buffer().clone(), - } - } - - pub fn all_valid(&self) -> bool { - match self { - LogicalValidity::AllValid(_) => true, - LogicalValidity::AllInvalid(_) => false, - LogicalValidity::Mask(mask) => mask.all_true(), - } - } - - pub fn all_invalid(&self) -> bool { - match self { - LogicalValidity::AllValid(_) => false, - LogicalValidity::AllInvalid(_) => true, - LogicalValidity::Mask(mask) => mask.all_false(), - } - } - - pub fn is_valid(&self, index: usize) -> bool { - if index > self.len() { - vortex_panic!("Index out of bounds") - } - match self { - LogicalValidity::AllValid(_) => true, - LogicalValidity::AllInvalid(_) => false, - LogicalValidity::Mask(mask) => mask.value(index), - } - } - - pub fn len(&self) -> usize { - match self { - Self::AllValid(n) => *n, - Self::AllInvalid(n) => *n, - Self::Mask(a) => a.len(), - } - } - - pub fn is_empty(&self) -> bool { - match self { - Self::AllValid(n) => *n == 0, - Self::AllInvalid(n) => *n == 0, - Self::Mask(a) => a.len() == 0, - } - } - - pub fn into_validity(self, nullability: Nullability) -> Validity { +impl Validity { + pub fn from_mask(mask: Mask, nullability: Nullability) -> Self { assert!( - nullability == Nullability::Nullable || matches!(self, Self::AllValid(_)), + nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)), "NonNullable validity must be AllValid", ); - match self { - Self::AllValid(_) => match nullability { + match mask { + Mask::AllTrue(_) => match nullability { Nullability::NonNullable => Validity::NonNullable, Nullability::Nullable => Validity::AllValid, }, - Self::AllInvalid(_) => Validity::AllInvalid, - Self::Mask(a) => Validity::Array(a.into_array()), + Mask::AllFalse(_) => Validity::AllInvalid, + Mask::Values(values) => Validity::Array(values.into_array()), } } +} - pub fn null_count(&self) -> usize { +impl IntoArrayData for Mask { + fn into_array(self) -> ArrayData { match self { - Self::AllValid(_) => 0, - Self::AllInvalid(len) => *len, - Self::Mask(a) => a.len() - a.true_count(), + Self::AllTrue(len) => ConstantArray::new(true, len).into_array(), + Self::AllFalse(len) => ConstantArray::new(false, len).into_array(), + Self::Values(a) => a.into_array(), } } } -impl IntoArrayData for LogicalValidity { +impl IntoArrayData for &MaskValues { fn into_array(self) -> ArrayData { - match self { - Self::AllValid(len) => ConstantArray::new(true, len).into_array(), - Self::AllInvalid(len) => ConstantArray::new(false, len).into_array(), - Self::Mask(a) => a.into_array(), - } + BoolArray::new(self.boolean_buffer().clone(), Nullability::NonNullable).into_array() } } @@ -575,7 +497,7 @@ mod tests { use vortex_mask::Mask; use crate::array::{BoolArray, PrimitiveArray}; - use crate::validity::{LogicalValidity, Validity}; + use crate::validity::Validity; use crate::IntoArrayData; #[rstest] @@ -618,13 +540,12 @@ mod tests { #[test] #[should_panic] fn into_validity_nullable() { - LogicalValidity::AllInvalid(10).into_validity(Nullability::NonNullable); + Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable); } #[test] #[should_panic] fn into_validity_nullable_array() { - LogicalValidity::Mask(Mask::from_iter(vec![true, false])) - .into_validity(Nullability::NonNullable); + Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable); } } diff --git a/vortex-layout/src/layouts/chunked/stats_table.rs b/vortex-layout/src/layouts/chunked/stats_table.rs index 8e19634cfb..060f2759b5 100644 --- a/vortex-layout/src/layouts/chunked/stats_table.rs +++ b/vortex-layout/src/layouts/chunked/stats_table.rs @@ -94,7 +94,7 @@ impl StatsTable { .as_slice::() .iter() .enumerate() - .filter_map(|(i, v)| validity.is_valid(i).then_some(*v)) + .filter_map(|(i, v)| validity.value(i).then_some(*v)) .sum(); stats_set.set(*stat, sum); } diff --git a/vortex-mask/src/bitand.rs b/vortex-mask/src/bitand.rs index 04e90aa08e..58446b99bf 100644 --- a/vortex-mask/src/bitand.rs +++ b/vortex-mask/src/bitand.rs @@ -2,7 +2,7 @@ use std::ops::BitAnd; use vortex_error::vortex_panic; -use crate::Mask; +use crate::{AllOr, Mask}; impl BitAnd for &Mask { type Output = Mask; @@ -11,30 +11,13 @@ impl BitAnd for &Mask { if self.len() != rhs.len() { vortex_panic!("Masks must have the same length"); } - if self.true_count() == 0 || rhs.true_count() == 0 { - return Mask::new_false(self.len()); - } - if self.true_count() == self.len() { - return rhs.clone(); - } - if rhs.true_count() == self.len() { - return self.clone(); - } - if let (Some(lhs), Some(rhs)) = (self.0.buffer.get(), rhs.0.buffer.get()) { - return Mask::from_buffer(lhs & rhs); + match (self.boolean_buffer(), rhs.boolean_buffer()) { + (AllOr::All, _) => rhs.clone(), + (_, AllOr::All) => self.clone(), + (AllOr::None, _) => Mask::new_false(self.len()), + (_, AllOr::None) => Mask::new_false(self.len()), + (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs & rhs), } - - if let (Some(lhs), Some(rhs)) = (self.0.indices.get(), rhs.0.indices.get()) { - // TODO(ngates): this may only make sense for sparse indices. - return Mask::from_intersection_indices( - self.len(), - lhs.iter().copied(), - rhs.iter().copied(), - ); - } - - // TODO(ngates): we could perform a more efficient bitandion for slices. - Mask::from_buffer(self.boolean_buffer() & rhs.boolean_buffer()) } } diff --git a/vortex-mask/src/eq.rs b/vortex-mask/src/eq.rs index 917bfa0053..b4a932da6e 100644 --- a/vortex-mask/src/eq.rs +++ b/vortex-mask/src/eq.rs @@ -9,29 +9,7 @@ impl PartialEq for Mask { return false; } - // Since the true counts are the same, a full or empty mask is equal to the other mask. - if self.true_count() == 0 || self.true_count() == self.len() { - return true; - } - - // Compare the buffer if both masks are non-empty. - if let (Some(buffer), Some(other)) = (self.0.buffer.get(), other.0.buffer.get()) { - return buffer == other; - } - - // Compare the indices if both masks are non-empty. - if let (Some(indices), Some(other)) = (self.0.indices.get(), other.0.indices.get()) { - return indices == other; - } - - // Compare the slices if both masks are non-empty. - if let (Some(slices), Some(other)) = (self.0.slices.get(), other.0.slices.get()) { - return slices == other; - } - - // Otherwise, we fall back to comparison based on sparsity. - // We could go further an exhaustively check whose OnceLocks are initialized, but that's - // probably not worth the effort. + // TODO(ngates): we could compare by indices if density is low enough self.boolean_buffer() == other.boolean_buffer() } } diff --git a/vortex-mask/src/intersect_by_rank.rs b/vortex-mask/src/intersect_by_rank.rs index 805df56072..3f64970b34 100644 --- a/vortex-mask/src/intersect_by_rank.rs +++ b/vortex-mask/src/intersect_by_rank.rs @@ -1,4 +1,4 @@ -use crate::Mask; +use crate::{AllOr, Mask}; impl Mask { /// Take the intersection of the `mask` with the set of true values in `self`. @@ -22,28 +22,25 @@ impl Mask { pub fn intersect_by_rank(&self, mask: &Mask) -> Mask { assert_eq!(self.true_count(), mask.len()); - if mask.true_count() == mask.len() { - return self.clone(); + match (self.indices(), mask.indices()) { + (AllOr::All, _) => mask.clone(), + (_, AllOr::All) => self.clone(), + (AllOr::None, _) => Self::new_false(0), + (_, AllOr::None) => Self::new_false(self.len()), + (AllOr::Some(self_indices), AllOr::Some(mask_indices)) => { + Self::from_indices( + self.len(), + mask_indices + .iter() + .map(|idx| + // This is verified as safe because we know that the indices are less than the + // mask.len() and we known mask.len() <= self.len(), + // implied by `self.true_count() == mask.len()`. + unsafe{*self_indices.get_unchecked(*idx)}) + .collect(), + ) + } } - - if mask.true_count() == 0 { - return Self::new_false(self.len()); - } - - // TODO(joe): support other fast paths, not converting self & mask into indices, - // however indices are better for sparse masks, so this is the common case for now. - let indices = self.0.indices(); - Self::from_indices( - self.len(), - mask.indices() - .iter() - .map(|idx| - // This is verified as safe because we know that the indices are less than the - // mask.len() and we known mask.len() <= self.len(), - // implied by `self.true_count() == mask.len()`. - unsafe{*indices.get_unchecked(*idx)}) - .collect(), - ) } } diff --git a/vortex-mask/src/iter_bools.rs b/vortex-mask/src/iter_bools.rs index bdbb16d406..28f7506c97 100644 --- a/vortex-mask/src/iter_bools.rs +++ b/vortex-mask/src/iter_bools.rs @@ -1,7 +1,6 @@ use std::iter; -use std::iter::{Peekable, TrustedLen}; -use crate::Mask; +use crate::{AllOr, Mask}; impl Mask { /// Provides a closure with an iterator over the boolean values of the mask. @@ -15,133 +14,14 @@ impl Mask { where F: FnMut(&mut dyn Iterator) -> T, { - if self.all_true() { - return f(&mut iter::repeat(true).take(self.len())); - } - - if self.all_false() { - return f(&mut iter::repeat(false).take(self.len())); - } - - // We check for representations in order of performance, with BooleanBuffer iteration last. - - if let Some(indices) = self.0.maybe_indices() { - let mut iter = IndicesBoolIter { - indices: indices.iter().copied().peekable(), - pos: 0, - len: self.len(), - }; - return f(&mut iter); - } - - if let Some(slices) = self.0.maybe_slices() { - let mut iter = SlicesBoolIter { - slices: slices.iter().copied().peekable(), - pos: 0, - len: self.len(), - }; - return f(&mut iter); - } - - if let Some(buffer) = self.0.maybe_buffer() { - return f(&mut buffer.iter()); - } - - unreachable!() - } -} - -struct IndicesBoolIter -where - I: Iterator, -{ - indices: Peekable, - pos: usize, - len: usize, -} - -impl Iterator for IndicesBoolIter -where - I: Iterator, -{ - type Item = bool; - - fn next(&mut self) -> Option { - match self.indices.peek() { - None => { - if self.pos < self.len { - self.pos += 1; - return Some(false); - } - None - } - Some(next) => { - if *next == self.pos { - self.indices.next(); - self.pos += 1; - Some(true) - } else { - self.pos += 1; - Some(false) - } - } + match self.boolean_buffer() { + AllOr::All => f(&mut iter::repeat(true).take(self.len())), + AllOr::None => f(&mut iter::repeat(false).take(self.len())), + AllOr::Some(buffer) => f(&mut buffer.iter()), } } - - fn size_hint(&self) -> (usize, Option) { - let remaining = self.len - self.pos; - (remaining, Some(remaining)) - } -} - -unsafe impl> TrustedLen for IndicesBoolIter {} - -#[allow(dead_code)] -struct SlicesBoolIter -where - I: Iterator, -{ - slices: Peekable, - pos: usize, - len: usize, } -impl Iterator for SlicesBoolIter -where - I: Iterator, -{ - type Item = bool; - - fn next(&mut self) -> Option { - let Some((start, end)) = self.slices.peek() else { - if self.pos < self.len { - self.pos += 1; - return Some(false); - } - return None; - }; - - if self.pos < *start { - self.pos += 1; - return Some(false); - } - - if self.pos == *end - 1 { - self.slices.next(); - } - - self.pos += 1; - Some(true) - } - - fn size_hint(&self) -> (usize, Option) { - let remaining = self.len - self.pos; - (remaining, Some(remaining)) - } -} - -unsafe impl> TrustedLen for SlicesBoolIter {} - #[cfg(test)] mod test { use itertools::Itertools; diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 567701c6bf..a88bdcde2f 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -7,104 +7,142 @@ mod intersect_by_rank; mod iter_bools; use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; use std::sync::{Arc, OnceLock}; -use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use itertools::Itertools; -use vortex_error::vortex_panic; -/// If the mask selects more than this fraction of rows, iterate over slices instead of indices. -/// -/// Threshold of 0.8 chosen based on Arrow Rust, which is in turn based on: -/// -const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; +/// Represents a set of values that are all included, all excluded, or some mixture of both. +pub enum AllOr { + /// All values are included. + All, + /// No values are included. + None, + /// Some values are included. + Some(T), +} + +impl AllOr { + /// Returns the `Some` variant of the enum, or a default value. + pub fn unwrap_or_else(self, all_true: F, all_false: G) -> T + where + F: FnOnce() -> T, + G: FnOnce() -> T, + { + match self { + Self::Some(v) => v, + AllOr::All => all_true(), + AllOr::None => all_false(), + } + } +} + +impl AllOr<&T> { + /// Clone the inner value. + pub fn cloned(self) -> AllOr + where + T: Clone, + { + match self { + Self::All => AllOr::All, + Self::None => AllOr::None, + Self::Some(v) => AllOr::Some(v.clone()), + } + } +} + +impl Debug for AllOr +where + T: Debug, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::All => f.write_str("All"), + Self::None => f.write_str("None"), + Self::Some(v) => f.debug_tuple("Some").field(v).finish(), + } + } +} + +impl PartialEq for AllOr +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::All, Self::All) => true, + (Self::None, Self::None) => true, + (Self::Some(lhs), Self::Some(rhs)) => lhs == rhs, + _ => false, + } + } +} + +impl Eq for AllOr where T: Eq {} /// Represents a set of sorted unique positive integers. /// /// A [`Mask`] can be constructed from various representations, and converted to various /// others. Internally, these are cached. #[derive(Clone, Debug)] -pub struct Mask(pub(crate) Arc); +pub enum Mask { + /// All values are included. + AllTrue(usize), + /// No values are included. + AllFalse(usize), + /// Some values are included, represented as a [`BooleanBuffer`]. + Values(Arc), +} +/// Represents the values of a [`Mask`] that contains some true and some false elements. #[derive(Debug)] -struct Inner { - // The three possible representations of the mask. - buffer: OnceLock, +pub struct MaskValues { + buffer: BooleanBuffer, + + // We cached the indices and slices representations, since it can be faster than iterating + // the bit-mask over and over again. indices: OnceLock>, slices: OnceLock>, // Pre-computed values. - len: usize, true_count: usize, // i.e., the fraction of values that are true - selectivity: f64, + density: f64, } -impl Inner { - /// Returns a [`BooleanBuffer`] representation of the mask if one exists. - pub(crate) fn maybe_buffer(&self) -> Option<&BooleanBuffer> { - self.buffer.get() +impl MaskValues { + /// Returns the length of the mask. + #[inline] + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.buffer.len() } - /// Constructs a [`BooleanBuffer`] from one of the other representations. - fn buffer(&self) -> &BooleanBuffer { - self.buffer.get_or_init(|| { - if self.true_count == 0 { - return BooleanBuffer::new_unset(self.len); - } - - if self.true_count == self.len { - return BooleanBuffer::new_set(self.len); - } - - if let Some(indices) = self.indices.get() { - let mut buf = BooleanBufferBuilder::new(self.len); - // TODO(ngates): for dense indices, we can do better by collecting into u64s. - buf.append_n(self.len, false); - indices.iter().for_each(|idx| buf.set_bit(*idx, true)); - debug_assert_eq!(buf.len(), self.len); - return BooleanBuffer::from(buf); - } - - if let Some(slices) = self.slices.get() { - let mut buf = BooleanBufferBuilder::new(self.len); - for (start, end) in slices.iter().copied() { - buf.append_n(start - buf.len(), false); - buf.append_n(end - start, true); - } - if let Some((_, end)) = slices.last() { - buf.append_n(self.len - end, false); - } - debug_assert_eq!(buf.len(), self.len); - return BooleanBuffer::from(buf); - } + /// Returns the true count of the mask. + pub fn true_count(&self) -> usize { + self.true_count + } - vortex_panic!("No mask representation found") - }) + /// Returns the boolean buffer representation of the mask. + pub fn boolean_buffer(&self) -> &BooleanBuffer { + &self.buffer } - /// Returns the indices representation of the mask if one exists. - pub(crate) fn maybe_indices(&self) -> Option<&[usize]> { - self.indices.get().map(|v| v.as_slice()) + /// Returns the boolean value at a given index. + pub fn value(&self, index: usize) -> bool { + self.buffer.value(index) } /// Constructs an indices vector from one of the other representations. - fn indices(&self) -> &[usize] { + pub fn indices(&self) -> &[usize] { self.indices.get_or_init(|| { if self.true_count == 0 { return vec![]; } - if self.true_count == self.len { - return (0..self.len).collect(); - } - - if let Some(buffer) = self.buffer.get() { - let mut indices = Vec::with_capacity(self.true_count); - indices.extend(buffer.set_indices()); - debug_assert!(indices.is_sorted()); - assert_eq!(indices.len(), self.true_count); - return indices; + if self.true_count == self.len() { + return (0..self.len()).collect(); } if let Some(slices) = self.slices.get() { @@ -115,133 +153,96 @@ impl Inner { return indices; } - vortex_panic!("No mask representation found") + let mut indices = Vec::with_capacity(self.true_count); + indices.extend(self.buffer.set_indices()); + debug_assert!(indices.is_sorted()); + assert_eq!(indices.len(), self.true_count); + indices }) } - /// Returns the slices representation of the mask if one exists. - pub(crate) fn maybe_slices(&self) -> Option<&[(usize, usize)]> { - self.slices.get().map(|v| v.as_slice()) - } - /// Constructs a slices vector from one of the other representations. #[allow(clippy::cast_possible_truncation)] - fn slices(&self) -> &[(usize, usize)] { + pub fn slices(&self) -> &[(usize, usize)] { self.slices.get_or_init(|| { - if self.true_count == self.len { - return vec![(0, self.len)]; + if self.true_count == self.len() { + return vec![(0, self.len())]; } - if let Some(buffer) = self.buffer.get() { - return buffer.set_slices().collect(); - } - - if let Some(indices) = self.indices.get() { - // Expected number of contiguous slices assuming a uniform distribution of true values. - let expected_num_slices = - (self.selectivity * (self.len - self.true_count + 1) as f64).round() as usize; - let mut slices = Vec::with_capacity(expected_num_slices); - let mut iter = indices.iter().copied(); - - // Handle empty input - let Some(first) = iter.next() else { - return slices; - }; - - let mut start = first; - let mut prev = first; - for curr in iter { - if curr != prev + 1 { - slices.push((start, prev + 1)); - start = curr; - } - prev = curr; - } - - // Don't forget the last range - slices.push((start, prev + 1)); - - return slices; - } - - vortex_panic!("No mask representation found") + self.buffer.set_slices().collect() }) } - fn first(&self) -> Option { - if self.true_count == 0 { - return None; - } - if self.true_count == self.len { - return Some(0); - } - if let Some(buffer) = self.buffer.get() { - return buffer.set_indices().next(); - } - if let Some(indices) = self.indices.get() { - return indices.first().copied(); - } - if let Some(slices) = self.slices.get() { - return slices.first().map(|(start, _)| *start); + /// Return an iterator over either indices or slices of the mask based on a density threshold. + pub fn threshold_iter(&self, threshold: f64) -> MaskIter { + if self.density >= threshold { + MaskIter::Slices(self.slices()) + } else { + MaskIter::Indices(self.indices()) } - None } } impl Mask { /// Create a new Mask where all values are set. pub fn new_true(length: usize) -> Self { - Self(Arc::new(Inner { - buffer: Default::default(), - indices: Default::default(), - slices: Default::default(), - len: length, - true_count: length, - selectivity: 1.0, - })) + Self::AllTrue(length) } /// Create a new Mask where no values are set. pub fn new_false(length: usize) -> Self { - Self(Arc::new(Inner { - buffer: Default::default(), - indices: Default::default(), - slices: Default::default(), - len: length, - true_count: 0, - selectivity: 0.0, - })) + Self::AllFalse(length) } /// Create a new [`Mask`] from a [`BooleanBuffer`]. pub fn from_buffer(buffer: BooleanBuffer) -> Self { - let true_count = buffer.count_set_bits(); let len = buffer.len(); - Self(Arc::new(Inner { - buffer: OnceLock::from(buffer), + let true_count = buffer.count_set_bits(); + + if true_count == 0 { + return Self::AllFalse(len); + } + if true_count == len { + return Self::AllTrue(len); + } + + Self::Values(Arc::new(MaskValues { + buffer, indices: Default::default(), slices: Default::default(), - len, true_count, - selectivity: true_count as f64 / len as f64, + density: true_count as f64 / len as f64, })) } /// Create a new [`Mask`] from a [`Vec`]. - pub fn from_indices(len: usize, vec: Vec) -> Self { - let true_count = vec.len(); - assert!(vec.is_sorted(), "Mask indices must be sorted"); + pub fn from_indices(len: usize, indices: Vec) -> Self { + let true_count = indices.len(); + assert!(indices.is_sorted(), "Mask indices must be sorted"); assert!( - vec.last().is_none_or(|&idx| idx < len), + indices.last().is_none_or(|&idx| idx < len), "Mask indices must be in bounds (len={len})" ); - Self(Arc::new(Inner { - buffer: Default::default(), - indices: OnceLock::from(vec), + + if true_count == 0 { + return Self::AllFalse(len); + } + if true_count == len { + return Self::AllTrue(len); + } + + let mut buf = BooleanBufferBuilder::new(len); + // TODO(ngates): for dense indices, we can do better by collecting into u64s. + buf.append_n(len, false); + indices.iter().for_each(|idx| buf.set_bit(*idx, true)); + debug_assert_eq!(buf.len(), len); + + Self::Values(Arc::new(MaskValues { + buffer: buf.finish(), + indices: OnceLock::from(indices), slices: Default::default(), - len, true_count, - selectivity: true_count as f64 / len as f64, + density: true_count as f64 / len as f64, })) } @@ -252,18 +253,34 @@ impl Mask { Self::from_slices_unchecked(len, vec) } - fn from_slices_unchecked(len: usize, vec: Vec<(usize, usize)>) -> Self { + fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self { #[cfg(debug_assertions)] - Self::check_slices(len, &vec); + Self::check_slices(len, &slices); - let true_count = vec.iter().map(|(b, e)| e - b).sum(); - Self(Arc::new(Inner { - buffer: Default::default(), + let true_count = slices.iter().map(|(b, e)| e - b).sum(); + if true_count == 0 { + return Self::AllFalse(len); + } + if true_count == len { + return Self::AllTrue(len); + } + + let mut buf = BooleanBufferBuilder::new(len); + for (start, end) in slices.iter().copied() { + buf.append_n(start - buf.len(), false); + buf.append_n(end - start, true); + } + if let Some((_, end)) = slices.last() { + buf.append_n(len - end, false); + } + debug_assert_eq!(buf.len(), len); + + Self::Values(Arc::new(MaskValues { + buffer: buf.finish(), indices: Default::default(), - slices: OnceLock::from(vec), - len, + slices: OnceLock::from(slices), true_count, - selectivity: true_count as f64 / len as f64, + density: true_count as f64 / len as f64, })) } @@ -315,28 +332,44 @@ impl Mask { /// Returns the length of the mask (not the number of true values). #[inline] - // There is no good definition of is_empty, does it mean len == 0 or true_count == 0? + // It's confusing to provide is_empty, does it mean len == 0 or true_count == 0? #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { - self.0.len + match &self { + Self::AllTrue(len) => *len, + Self::AllFalse(len) => *len, + Self::Values(values) => values.buffer.len(), + } } /// Get the true count of the mask. #[inline] pub fn true_count(&self) -> usize { - self.0.true_count + match &self { + Self::AllTrue(len) => *len, + Self::AllFalse(_) => 0, + Self::Values(values) => values.true_count, + } } /// Get the false count of the mask. #[inline] pub fn false_count(&self) -> usize { - self.len() - self.true_count() + match &self { + Self::AllTrue(_) => 0, + Self::AllFalse(len) => *len, + Self::Values(values) => values.buffer.len() - values.true_count, + } } /// Returns true if all values in the mask are true. #[inline] pub fn all_true(&self) -> bool { - self.true_count() == self.len() + match &self { + Self::AllTrue(_) => true, + Self::AllFalse(_) => false, + Self::Values(values) => values.buffer.len() == values.true_count, + } } /// Returns true if all values in the mask are false. @@ -345,10 +378,14 @@ impl Mask { self.true_count() == 0 } - /// Return the selectivity of the full mask. + /// Return the density of the full mask. #[inline] - pub fn selectivity(&self) -> f64 { - self.0.selectivity + pub fn density(&self) -> f64 { + match &self { + Self::AllTrue(_) => 1.0, + Self::AllFalse(_) => 0.0, + Self::Values(values) => values.density, + } } /// Returns the boolean value at a given index. @@ -357,85 +394,101 @@ impl Mask { /// /// Panics if the index is out of bounds. pub fn value(&self, idx: usize) -> bool { - if self.all_true() { - return true; + match self { + Mask::AllTrue(_) => true, + Mask::AllFalse(_) => false, + Mask::Values(values) => values.buffer.value(idx), } - if self.all_false() { - return false; - } - // NOTE(ngates): we should follow up and make BooleanBuffer the default impl for this. - self.0.buffer().value(idx) } - /// Get the canonical representation of the mask. - pub fn boolean_buffer(&self) -> &BooleanBuffer { - self.0.buffer() + /// Returns the first true index in the mask. + pub fn first(&self) -> Option { + match &self { + Self::AllTrue(len) => (*len > 0).then_some(0), + Self::AllFalse(_) => None, + Self::Values(values) => { + if let Some(indices) = values.indices.get() { + return indices.first().copied(); + } + if let Some(slices) = values.slices.get() { + return slices.first().map(|(start, _)| *start); + } + values.buffer.set_indices().next() + } + } } - /// Get the indices of the true values in the mask. - pub fn indices(&self) -> &[usize] { - self.0.indices() + /// Slice the mask. + pub fn slice(&self, offset: usize, length: usize) -> Self { + assert!(offset + length <= self.len()); + match &self { + Self::AllTrue(_) => Self::new_true(length), + Self::AllFalse(_) => Self::new_false(length), + Self::Values(values) => Self::from_buffer(values.buffer.slice(offset, length)), + } } - /// Get the slices of the true values in the mask. - pub fn slices(&self) -> &[(usize, usize)] { - self.0.slices() + /// Return the boolean buffer representation of the mask. + pub fn boolean_buffer(&self) -> AllOr<&BooleanBuffer> { + match &self { + Self::AllTrue(_) => AllOr::All, + Self::AllFalse(_) => AllOr::None, + Self::Values(values) => AllOr::Some(&values.buffer), + } } - /// Returns the first true index in the mask. - pub fn first(&self) -> Option { - self.0.first() + /// Return a boolean buffer representation of the mask, allocating new buffers for all-true + /// and all-false variants. + pub fn to_boolean_buffer(&self) -> BooleanBuffer { + match self { + Self::AllTrue(l) => BooleanBuffer::new_set(*l), + Self::AllFalse(l) => BooleanBuffer::new_unset(*l), + Self::Values(values) => values.boolean_buffer().clone(), + } } - /// Returns the best iterator based on a selectivity threshold. - /// - /// Currently, this threshold is fixed at 0.8 based on Arrow Rust. - pub fn iter(&self) -> MaskIter { - if self.selectivity() > FILTER_SLICES_SELECTIVITY_THRESHOLD { - MaskIter::Slices(self.slices()) - } else { - MaskIter::Indices(self.indices()) + /// Returns an Arrow null buffer representation of the mask. + pub fn to_null_buffer(&self) -> Option { + match self { + Mask::AllTrue(_) => None, + Mask::AllFalse(l) => Some(NullBuffer::new_null(*l)), + Mask::Values(values) => Some(NullBuffer::from(values.buffer.clone())), } } - /// Slice the mask. - pub fn slice(&self, offset: usize, length: usize) -> Self { - if self.true_count() == 0 { - return Self::new_false(length); - } - if self.true_count() == self.len() { - return Self::new_true(length); + /// Return the indices representation of the mask. + pub fn indices(&self) -> AllOr<&[usize]> { + match &self { + Self::AllTrue(_) => AllOr::All, + Self::AllFalse(_) => AllOr::None, + Self::Values(values) => AllOr::Some(values.indices()), } + } - if let Some(buffer) = self.0.buffer.get() { - return Self::from_buffer(buffer.slice(offset, length)); + /// Return the slices representation of the mask. + pub fn slices(&self) -> AllOr<&[(usize, usize)]> { + match &self { + Self::AllTrue(_) => AllOr::All, + Self::AllFalse(_) => AllOr::None, + Self::Values(values) => AllOr::Some(values.slices()), } + } - let end = offset + length; - - if let Some(indices) = self.0.indices.get() { - let indices = indices - .iter() - .copied() - .skip_while(|idx| *idx < offset) - .take_while(|idx| *idx < end) - .map(|idx| idx - offset) - .collect(); - return Self::from_indices(length, indices); + /// Return an iterator over either indices or slices of the mask based on a density threshold. + pub fn threshold_iter(&self, threshold: f64) -> AllOr { + match &self { + Self::AllTrue(_) => AllOr::All, + Self::AllFalse(_) => AllOr::None, + Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)), } + } - if let Some(slices) = self.0.slices.get() { - let slices = slices - .iter() - .copied() - .skip_while(|(_, e)| *e <= offset) - .take_while(|(s, _)| *s < end) - .map(|(s, e)| (s.max(offset), e.min(end))) - .collect(); - return Self::from_slices_unchecked(length, slices); + /// Return [`MaskValues`] if the mask is not all true or all false. + pub fn values(&self) -> Option<&MaskValues> { + match self { + Self::Values(values) => Some(values), + _ => None, } - - vortex_panic!("No mask representation found") } } @@ -468,10 +521,10 @@ mod test { let mask = Mask::new_true(5); assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 5); - assert_eq!(mask.selectivity(), 1.0); - assert_eq!(mask.indices(), &[0, 1, 2, 3, 4]); - assert_eq!(mask.slices(), &[(0, 5)]); - assert_eq!(mask.boolean_buffer(), &BooleanBuffer::new_set(5)); + assert_eq!(mask.density(), 1.0); + assert_eq!(mask.indices(), AllOr::All); + assert_eq!(mask.slices(), AllOr::All); + assert_eq!(mask.boolean_buffer(), AllOr::All,); } #[test] @@ -479,10 +532,10 @@ mod test { let mask = Mask::new_false(5); assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 0); - assert_eq!(mask.selectivity(), 0.0); - assert_eq!(mask.indices(), &[] as &[usize]); - assert_eq!(mask.slices(), &[]); - assert_eq!(mask.boolean_buffer(), &BooleanBuffer::new_unset(5)); + assert_eq!(mask.density(), 0.0); + assert_eq!(mask.indices(), AllOr::None); + assert_eq!(mask.slices(), AllOr::None); + assert_eq!(mask.boolean_buffer(), AllOr::None,); } #[test] @@ -496,12 +549,12 @@ mod test { for mask in &masks { assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 3); - assert_eq!(mask.selectivity(), 0.6); - assert_eq!(mask.indices(), &[0, 2, 3]); - assert_eq!(mask.slices(), &[(0, 1), (2, 4)]); + assert_eq!(mask.density(), 0.6); + assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..])); + assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..])); assert_eq!( - &mask.boolean_buffer().iter().collect::>(), - &[true, false, true, true, false] + mask.boolean_buffer(), + AllOr::Some(&BooleanBuffer::from_iter([true, false, true, true, false])) ); } } diff --git a/vortex-sampling-compressor/src/compressors/for.rs b/vortex-sampling-compressor/src/compressors/for.rs index 439890ca7e..e4af197281 100644 --- a/vortex-sampling-compressor/src/compressors/for.rs +++ b/vortex-sampling-compressor/src/compressors/for.rs @@ -34,7 +34,7 @@ impl EncodingCompressor for FoRCompressor { } // For all-null, cannot encode. - if parray.logical_validity().ok()?.all_invalid() { + if parray.logical_validity().ok()?.all_false() { return None; } diff --git a/vortex-scan/src/range_scan.rs b/vortex-scan/src/range_scan.rs index 6867110cc1..bd3cdecd3c 100644 --- a/vortex-scan/src/range_scan.rs +++ b/vortex-scan/src/range_scan.rs @@ -131,7 +131,7 @@ impl RangeScanner { self.state = State::Ready(None); } else if !conjuncts_rev.is_empty() { self.mask = mask; - let mask = if self.mask.selectivity() < APPLY_FILTER_SELECTIVITY_THRESHOLD { + let mask = if self.mask.density() < APPLY_FILTER_SELECTIVITY_THRESHOLD { self.mask.clone() } else { Mask::new_true(self.mask.len()) diff --git a/vortex-scan/src/row_mask.rs b/vortex-scan/src/row_mask.rs index a577dbcf91..2aa5d51f26 100644 --- a/vortex-scan/src/row_mask.rs +++ b/vortex-scan/src/row_mask.rs @@ -1,10 +1,10 @@ use std::cmp::{max, min}; use std::fmt::{Display, Formatter}; -use std::ops::{BitAnd, RangeBounds}; +use std::ops::RangeBounds; use vortex_array::array::BooleanBuffer; use vortex_array::compute::{filter, slice, try_cast}; -use vortex_array::validity::{ArrayValidity, LogicalValidity}; +use vortex_array::validity::ArrayValidity; use vortex_array::{ArrayDType, ArrayData, IntoArrayVariant}; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, PType}; @@ -79,13 +79,7 @@ impl RowMask { /// /// True-valued positions are kept by the returned mask. fn from_mask_array(array: &ArrayData, begin: u64) -> VortexResult { - match array.logical_validity()? { - LogicalValidity::AllValid(_) => Ok(Self::new(Mask::try_from(array.clone())?, begin)), - LogicalValidity::AllInvalid(_) => { - Ok(Self::new_invalid_between(begin, begin + array.len() as u64)) - } - LogicalValidity::Mask(mask) => Ok(Self::new(mask, begin)), - } + Ok(Self::new(array.logical_validity()?, begin)) } /// Construct a RowMask from an integral array. @@ -137,50 +131,6 @@ impl RowMask { self.end <= start || end <= self.begin } - /// Perform an intersection with another [`RowMask`], returning only rows that appear in both. - pub fn and_rowmask(self, other: RowMask) -> VortexResult { - if other.true_count() == other.len() { - return Ok(self); - } - - // If both masks align perfectly - if self.begin == other.begin && self.end == other.end { - return Ok(RowMask::new(self.mask.bitand(&other.mask), self.begin)); - } - - // Disjoint row ranges - if self.end <= other.begin || self.begin >= other.end { - return Ok(RowMask::new_invalid_between( - min(self.begin, other.begin), - max(self.end, other.end), - )); - } - - let output_begin = min(self.begin, other.begin); - let output_end = max(self.end, other.end); - let output_len = usize::try_from(output_end - output_begin) - .map_err(|_| vortex_err!("Range length does not fit into a usize"))?; - - let output_mask = Mask::from_intersection_indices( - output_len, - self.mask - .indices() - .iter() - .copied() - .map(|v| v as u64 + self.begin - output_begin) - .map(|v| usize::try_from(v).vortex_expect("mask index must fit into usize")), - other - .mask - .indices() - .iter() - .copied() - .map(|v| v as u64 + other.begin - output_begin) - .map(|v| usize::try_from(v).vortex_expect("mask index must fit into usize")), - ); - - Ok(Self::new(output_mask, output_begin)) - } - /// The beginning of the masked range. #[inline] pub fn begin(&self) -> u64 { @@ -357,70 +307,4 @@ mod tests { let array = PrimitiveArray::new(buffer![1.0, 2.0], Validity::AllInvalid).into_array(); RowMask::from_array(&array, 0, 2).unwrap(); } - - #[test] - fn test_and_rowmap_disjoint() { - let a = RowMask::from_array( - PrimitiveArray::new(buffer![1, 2, 3], Validity::AllValid).as_ref(), - 0, - 10, - ) - .unwrap(); - let b = RowMask::from_array( - PrimitiveArray::new(buffer![1, 2, 3], Validity::AllValid).as_ref(), - 15, - 20, - ) - .unwrap(); - - let output = a.and_rowmask(b).unwrap(); - - assert_eq!(output.begin, 0); - assert_eq!(output.end, 20); - assert_eq!(output.true_count(), 0); - } - - #[test] - fn test_and_rowmap_aligned() { - let a = RowMask::from_array( - PrimitiveArray::new(buffer![1, 2, 3], Validity::AllValid).as_ref(), - 0, - 10, - ) - .unwrap(); - let b = RowMask::from_array( - PrimitiveArray::new(buffer![1, 2, 7], Validity::AllValid).as_ref(), - 0, - 10, - ) - .unwrap(); - - let output = a.and_rowmask(b).unwrap(); - - assert_eq!(output.begin, 0); - assert_eq!(output.end, 10); - assert_eq!(output.true_count(), 2); - } - - #[test] - fn test_and_rowmap_intersect() { - let a = RowMask::from_array( - PrimitiveArray::new(buffer![1, 2, 3], Validity::AllValid).as_ref(), - 0, - 10, - ) - .unwrap(); - let b = RowMask::from_array( - PrimitiveArray::new(buffer!(1, 2, 7), Validity::AllValid).as_ref(), - 5, - 15, - ) - .unwrap(); - - let output = a.and_rowmask(b).unwrap(); - - assert_eq!(output.begin, 0); - assert_eq!(output.end, 15); - assert_eq!(output.true_count(), 0); - } }