Skip to content

Commit

Permalink
Refactor Vortex Mask (#2101)
Browse files Browse the repository at this point in the history
We now explicitly expose AllTrue, AllFalse, and Values (mixed)
variants.
  • Loading branch information
gatesn authored Jan 28, 2025
1 parent e67f256 commit 9ef1107
Show file tree
Hide file tree
Showing 65 changed files with 788 additions and 1,022 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions encodings/alp/src/alp/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex_array::encoding::ids;
use vortex_array::patches::{Patches, PatchesMetadata};
use vortex_array::stats::StatisticsVTable;
use vortex_array::validate::ValidateVTable;
use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
use vortex_array::validity::{ArrayValidity, ValidityVTable};
use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{
Expand All @@ -15,6 +15,7 @@ use vortex_array::{
};
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult};
use vortex_mask::Mask;

use crate::alp::{alp_encode, decompress, Exponents};

Expand Down Expand Up @@ -124,7 +125,7 @@ impl ValidityVTable<ALPArray> for ALPEncoding {
array.encoded().is_valid(index)
}

fn logical_validity(&self, array: &ALPArray) -> VortexResult<LogicalValidity> {
fn logical_validity(&self, array: &ALPArray) -> VortexResult<Mask> {
array.encoded().logical_validity()
}
}
Expand Down
11 changes: 5 additions & 6 deletions encodings/alp/src/alp_rd/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use vortex_array::encoding::ids;
use vortex_array::patches::{Patches, PatchesMetadata};
use vortex_array::stats::{StatisticsVTable, StatsSet};
use vortex_array::validate::ValidateVTable;
use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
use vortex_array::validity::{ArrayValidity, Validity, ValidityVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{
impl_encoding, ArrayDType, ArrayData, ArrayLen, Canonical, IntoCanonical, SerdeMetadata,
};
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::{vortex_bail, VortexExpect, VortexResult};
use vortex_mask::Mask;

use crate::alp_rd::alp_rd_decode;

Expand Down Expand Up @@ -210,8 +211,7 @@ impl IntoCanonical for ALPRDArray {
right_parts.into_buffer_mut::<u32>(),
self.left_parts_patches(),
)?,
self.logical_validity()?
.into_validity(self.dtype().nullability()),
Validity::from_mask(self.logical_validity()?, self.dtype().nullability()),
)
} else {
PrimitiveArray::new(
Expand All @@ -222,8 +222,7 @@ impl IntoCanonical for ALPRDArray {
right_parts.into_buffer_mut::<u64>(),
self.left_parts_patches(),
)?,
self.logical_validity()?
.into_validity(self.dtype().nullability()),
Validity::from_mask(self.logical_validity()?, self.dtype().nullability()),
)
};

Expand All @@ -237,7 +236,7 @@ impl ValidityVTable<ALPRDArray> for ALPRDEncoding {
array.left_parts().is_valid(index)
}

fn logical_validity(&self, array: &ALPRDArray) -> VortexResult<LogicalValidity> {
fn logical_validity(&self, array: &ALPRDArray) -> VortexResult<Mask> {
// Use validity from left_parts
array.left_parts().logical_validity()
}
Expand Down
1 change: 1 addition & 0 deletions encodings/bytebool/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ vortex-array = { workspace = true }
vortex-buffer = { workspace = true }
vortex-dtype = { workspace = true }
vortex-error = { workspace = true }
vortex-mask = { workspace = true }
vortex-scalar = { workspace = true }

[dev-dependencies]
Expand Down
5 changes: 3 additions & 2 deletions encodings/bytebool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use vortex_array::array::BoolArray;
use vortex_array::encoding::ids;
use vortex_array::stats::StatsSet;
use vortex_array::validate::ValidateVTable;
use vortex_array::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable};
use vortex_array::validity::{Validity, ValidityMetadata, ValidityVTable};
use vortex_array::variants::{BoolArrayTrait, VariantsVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{impl_encoding, ArrayLen, Canonical, IntoCanonical, SerdeMetadata};
use vortex_buffer::ByteBuffer;
use vortex_dtype::DType;
use vortex_error::{VortexExpect as _, VortexResult};
use vortex_mask::Mask;

impl_encoding!(
"vortex.bytebool",
Expand Down Expand Up @@ -116,7 +117,7 @@ impl ValidityVTable<ByteBoolArray> for ByteBoolEncoding {
array.validity().is_valid(index)
}

fn logical_validity(&self, array: &ByteBoolArray) -> VortexResult<LogicalValidity> {
fn logical_validity(&self, array: &ByteBoolArray) -> VortexResult<Mask> {
array.validity().to_logical(array.len())
}
}
Expand Down
19 changes: 9 additions & 10 deletions encodings/bytebool/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use num_traits::AsPrimitive;
use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity};
use vortex_array::validity::{ArrayValidity, Validity};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
use vortex_dtype::{match_each_integer_ptype, Nullability};
use vortex_error::{vortex_err, VortexResult};
use vortex_mask::Mask;
use vortex_scalar::Scalar;

use super::{ByteBoolArray, ByteBoolEncoding};
Expand Down Expand Up @@ -55,7 +56,7 @@ impl TakeFn<ByteBoolArray> for ByteBoolEncoding {
// FIXME(ngates): we should be operating over canonical validity, which doesn't
// have fallible is_valid function.
let arr = match validity {
LogicalValidity::AllValid(_) => {
Mask::AllTrue(_) => {
let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
indices.as_slice::<$I>()
.iter()
Expand All @@ -68,16 +69,14 @@ impl TakeFn<ByteBoolArray> for ByteBoolEncoding {

ByteBoolArray::from(bools).into_array()
}
LogicalValidity::AllInvalid(_) => {
ByteBoolArray::from(vec![None; indices.len()]).into_array()
}
LogicalValidity::Mask(mask) => {
Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
Mask::Values(values) => {
let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
indices.as_slice::<$I>()
.iter()
.map(|&idx| {
let idx = idx.as_();
if mask.value(idx) {
if values.value(idx) {
Some(bools[idx])
} else {
None
Expand All @@ -101,21 +100,21 @@ impl FillForwardFn<ByteBoolArray> for ByteBoolEncoding {
return Ok(array.to_array());
}
// all valid, but we need to convert to non-nullable
if validity.all_valid() {
if validity.all_true() {
return Ok(
ByteBoolArray::try_new(array.buffer().clone(), Validity::AllValid)?.into_array(),
);
}
// all invalid => fill with default value (false)
if validity.all_invalid() {
if validity.all_false() {
return Ok(
ByteBoolArray::try_from_vec(vec![false; array.len()], Validity::AllValid)?
.into_array(),
);
}

let validity = validity
.to_null_buffer()?
.to_null_buffer()
.ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;

let bools = array.as_slice();
Expand Down
13 changes: 7 additions & 6 deletions encodings/datetime-parts/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use vortex_array::compute::try_cast;
use vortex_array::encoding::ids;
use vortex_array::stats::StatsSet;
use vortex_array::validate::ValidateVTable;
use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable};
use vortex_array::validity::{ArrayValidity, Validity, ValidityVTable};
use vortex_array::variants::{ExtensionArrayTrait, VariantsVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{impl_encoding, ArrayDType, ArrayData, ArrayLen, IntoArrayData, SerdeMetadata};
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult, VortexUnwrap};
use vortex_mask::Mask;

impl_encoding!(
"vortex.datetimeparts",
Expand Down Expand Up @@ -100,10 +101,10 @@ impl DateTimePartsArray {

pub fn validity(&self) -> VortexResult<Validity> {
// FIXME(ngates): this function is weird... can we just use logical validity?
Ok(self
.days()
.logical_validity()?
.into_validity(self.dtype().nullability()))
Ok(Validity::from_mask(
self.days().logical_validity()?,
self.dtype().nullability(),
))
}
}

Expand Down Expand Up @@ -140,7 +141,7 @@ impl ValidityVTable<DateTimePartsArray> for DateTimePartsEncoding {
array.days().is_valid(index)
}

fn logical_validity(&self, array: &DateTimePartsArray) -> VortexResult<LogicalValidity> {
fn logical_validity(&self, array: &DateTimePartsArray) -> VortexResult<Mask> {
array.days().logical_validity()
}
}
Expand Down
8 changes: 4 additions & 4 deletions encodings/dict/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex_array::compute::{scalar_at, take};
use vortex_array::encoding::ids;
use vortex_array::stats::StatsSet;
use vortex_array::validate::ValidateVTable;
use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
use vortex_array::validity::{ArrayValidity, ValidityVTable};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{
Expand Down Expand Up @@ -90,7 +90,7 @@ impl ValidityVTable<DictArray> for DictEncoding {
array.values().is_valid(values_index)
}

fn logical_validity(&self, array: &DictArray) -> VortexResult<LogicalValidity> {
fn logical_validity(&self, array: &DictArray) -> VortexResult<Mask> {
if array.dtype().is_nullable() {
let primitive_codes = array.codes().into_primitive()?;
match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
Expand All @@ -99,10 +99,10 @@ impl ValidityVTable<DictArray> for DictEncoding {
let is_valid_buffer = BooleanBuffer::collect_bool(is_valid.len(), |idx| {
is_valid[idx] != 0
});
Ok(LogicalValidity::Mask(Mask::from_buffer(is_valid_buffer)))
Ok(Mask::from_buffer(is_valid_buffer))
})
} else {
Ok(LogicalValidity::AllValid(array.len()))
Ok(Mask::AllTrue(array.len()))
}
}
}
Expand Down
1 change: 0 additions & 1 deletion encodings/fastlanes/src/bitpacking/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ mod test {
.unwrap()
.to_null_buffer()
.unwrap()
.unwrap()
.into_inner()
.set_indices()
.collect::<Vec<_>>()
Expand Down
34 changes: 12 additions & 22 deletions encodings/fastlanes/src/bitpacking/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant};
use vortex_buffer::{Buffer, BufferMut};
use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
use vortex_error::VortexResult;
use vortex_mask::{Mask, MaskIter};
use vortex_error::{VortexExpect, VortexResult};
use vortex_mask::Mask;

use super::chunked_indices;
use crate::bitpacking::compute::take::UNPACK_CHUNK_THRESHOLD;
Expand Down Expand Up @@ -43,17 +43,20 @@ fn filter_primitive<T: NativePType + BitPacking + ArrowNativeType>(
.flatten();

// Short-circuit if the selectivity is high enough.
if mask.selectivity() > 0.8 {
if mask.density() > 0.8 {
return filter(array.clone().into_primitive()?.as_ref(), mask)
.and_then(|a| a.into_primitive());
}

let values: Buffer<T> = match mask.iter() {
MaskIter::Indices(indices) => {
filter_indices(array, mask.true_count(), indices.iter().copied())
}
MaskIter::Slices(slices) => filter_slices(array, mask.true_count(), slices.iter().copied()),
};
let values: Buffer<T> = filter_indices(
array,
mask.true_count(),
mask.values()
.vortex_expect("AllTrue and AllFalse handled by filter fn")
.indices()
.iter()
.copied(),
);

let mut values = PrimitiveArray::new(values, validity).reinterpret_cast(array.ptype());
if let Some(patches) = patches {
Expand Down Expand Up @@ -111,19 +114,6 @@ fn filter_indices<T: NativePType + BitPacking + ArrowNativeType>(
values.freeze()
}

fn filter_slices<T: NativePType + BitPacking + ArrowNativeType>(
array: &BitPackedArray,
indices_len: usize,
slices: impl Iterator<Item = (usize, usize)>,
) -> Buffer<T> {
// TODO(ngates): do this more efficiently.
filter_indices(
array,
indices_len,
slices.into_iter().flat_map(|(start, end)| start..end),
)
}

#[cfg(test)]
mod test {
use vortex_array::array::PrimitiveArray;
Expand Down
5 changes: 3 additions & 2 deletions encodings/fastlanes/src/bitpacking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use vortex_array::encoding::ids;
use vortex_array::patches::{Patches, PatchesMetadata};
use vortex_array::stats::{StatisticsVTable, StatsSet};
use vortex_array::validate::ValidateVTable;
use vortex_array::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable};
use vortex_array::validity::{Validity, ValidityMetadata, ValidityVTable};
use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{
Expand All @@ -16,6 +16,7 @@ use vortex_array::{
use vortex_buffer::ByteBuffer;
use vortex_dtype::{DType, NativePType, PType};
use vortex_error::{vortex_bail, vortex_err, VortexExpect as _, VortexResult};
use vortex_mask::Mask;

mod compress;
mod compute;
Expand Down Expand Up @@ -262,7 +263,7 @@ impl ValidityVTable<BitPackedArray> for BitPackedEncoding {
array.validity().is_valid(index)
}

fn logical_validity(&self, array: &BitPackedArray) -> VortexResult<LogicalValidity> {
fn logical_validity(&self, array: &BitPackedArray) -> VortexResult<Mask> {
array.validity().to_logical(array.len())
}
}
Expand Down
5 changes: 3 additions & 2 deletions encodings/fastlanes/src/delta/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use vortex_array::array::PrimitiveArray;
use vortex_array::encoding::ids;
use vortex_array::stats::{StatisticsVTable, StatsSet};
use vortex_array::validate::ValidateVTable;
use vortex_array::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTable};
use vortex_array::validity::{Validity, ValidityMetadata, ValidityVTable};
use vortex_array::variants::{PrimitiveArrayTrait, VariantsVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{
Expand All @@ -15,6 +15,7 @@ use vortex_array::{
use vortex_buffer::Buffer;
use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult};
use vortex_mask::Mask;

mod compress;
mod compute;
Expand Down Expand Up @@ -243,7 +244,7 @@ impl ValidityVTable<DeltaArray> for DeltaEncoding {
array.validity().is_valid(index)
}

fn logical_validity(&self, array: &DeltaArray) -> VortexResult<LogicalValidity> {
fn logical_validity(&self, array: &DeltaArray) -> VortexResult<Mask> {
array.validity().to_logical(array.len())
}
}
Expand Down
Loading

0 comments on commit 9ef1107

Please sign in to comment.