Skip to content

Commit

Permalink
Patching Bitpacked and ALP arrays doesn't require multiple copies (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Apr 4, 2024
1 parent 02d8180 commit 88d5fd0
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 41 deletions.
19 changes: 15 additions & 4 deletions vortex-alp/src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use itertools::Itertools;
use vortex::array::downcast::DowncastArrayBuiltin;
use vortex::array::primitive::PrimitiveArray;
use vortex::array::sparse::SparseArray;
use vortex::array::sparse::{SparseArray, SparseEncoding};
use vortex::array::{Array, ArrayRef};
use vortex::compress::{CompressConfig, CompressCtx, EncodingCompression};
use vortex::compute::flatten::flatten_primitive;
use vortex::compute::patch::patch;
use vortex::ptype::{NativePType, PType};
use vortex::scalar::Scalar;
use vortex_error::{vortex_bail, vortex_err, VortexResult};
Expand Down Expand Up @@ -126,13 +125,25 @@ pub fn decompress(array: &ALPArray) -> VortexResult<PrimitiveArray> {
})?;

if let Some(patches) = array.patches() {
// TODO(#121): right now, applying patches forces an extraneous copy of the array data
flatten_primitive(&patch(&decoded, patches)?)
patch_decoded(decoded, patches)
} else {
Ok(decoded)
}
}

/// Applies `patches` on top of an already decoded ALP array, consuming `array`.
///
/// Only sparse-encoded patches are handled: the sparse array's resolved indices give the
/// positions to overwrite and its flattened values give the replacements.
///
/// # Panics
///
/// Panics if `patches` uses any encoding other than `SparseEncoding`.
fn patch_decoded(array: PrimitiveArray, patches: &dyn Array) -> VortexResult<PrimitiveArray> {
    match patches.encoding().id() {
        SparseEncoding::ID => {
            match_each_alp_float_ptype!(array.ptype(), |$T| {
                let sparse = patches.as_sparse();
                let positions = sparse.resolved_indices();
                let replacements = flatten_primitive(sparse.values())?;
                array.patch(&positions, replacements.typed_data::<$T>())?
            })
        }
        _ => panic!("can't patch alp array with {}", patches),
    }
}

fn decompress_primitive<T: NativePType + ALPFloat>(
values: &[T::ALPInt],
exponents: &Exponents,
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/primitive/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn cast<P: NativePType, T: NativePType>(array: &[P]) -> VortexResult<Vec<T>> {
// TODO(ngates): allow configurable checked/unchecked casting
.map(|&v| {
T::from(v)
.ok_or_else(|| vortex_err!(ComputeError: "Failed to cast {} to {:?}", v, T::PTYPE))
.ok_or_else(|| vortex_err!(ComputeError: "Failed to cast {} to {}", v, T::PTYPE))
})
.collect()
}
Expand Down Expand Up @@ -69,6 +69,6 @@ mod test {
let VortexError::ComputeError(s, _) = error else {
unreachable!()
};
assert_eq!(s.to_string(), "Failed to cast -1 to U32");
assert_eq!(s.to_string(), "Failed to cast -1 to u32");
}
}
45 changes: 40 additions & 5 deletions vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ use std::sync::{Arc, RwLock};

use allocator_api2::alloc::Allocator;
use arrow_buffer::buffer::{Buffer, ScalarBuffer};
use itertools::Itertools;
use linkme::distributed_slice;
use num_traits::AsPrimitive;
pub use view::*;
use vortex_error::{vortex_bail, VortexResult};
use vortex_schema::{DType, Nullability};

use crate::accessor::ArrayAccessor;
use crate::array::primitive::compute::PrimitiveTrait;
use crate::array::validity::{Validity, ValidityView};
use crate::array::IntoArray;
use crate::array::{check_slice_bounds, Array, ArrayRef};
use crate::compute::ArrayCompute;
use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS};
use crate::formatter::{ArrayDisplay, ArrayFormatter};
use crate::iterator::ArrayIter;
Expand All @@ -28,11 +33,6 @@ mod serde;
mod stats;
mod view;

pub use view::*;

use crate::array::primitive::compute::PrimitiveTrait;
use crate::compute::ArrayCompute;

#[derive(Debug, Clone)]
pub struct PrimitiveArray {
buffer: Buffer,
Expand Down Expand Up @@ -151,10 +151,45 @@ impl PrimitiveArray {
self.buffer().typed_data()
}

/// Overwrites the elements at `positions` with the corresponding entries of `values`,
/// consuming the array and returning the patched result.
///
/// `positions[i]` is the index that receives `values[i]`. Validity is not patched (see the
/// TODO below).
///
/// # Errors
///
/// Returns an error if
/// * the array's ptype differs from `T::PTYPE`,
/// * `positions` and `values` have different lengths, or
/// * any position is outside the bounds of the array.
pub fn patch<P: AsPrimitive<usize>, T: NativePType>(
    mut self,
    positions: &[P],
    values: &[T],
) -> VortexResult<Self> {
    if self.ptype() != T::PTYPE {
        vortex_bail!(MismatchedTypes: self.dtype, T::PTYPE)
    }
    // Reject mismatched inputs up front with an error instead of letting `zip_eq` panic.
    if positions.len() != values.len() {
        vortex_bail!(
            "Patch positions and values must have the same length, got {} and {}",
            positions.len(),
            values.len()
        );
    }

    // Take over the buffer's storage as a Vec when `into_vec` succeeds; otherwise fall back
    // to copying the typed data so that it can be mutated.
    let mut own_values = self
        .buffer
        .into_vec::<T>()
        .unwrap_or_else(|b| Vec::from(b.typed_data::<T>()));
    let len = own_values.len();
    // TODO(robert): Also patch validity
    for (idx, value) in positions.iter().zip_eq(values.iter()) {
        let position: usize = (*idx).as_();
        // Checked access: an out-of-range position is reported to the caller, not a panic.
        let Some(slot) = own_values.get_mut(position) else {
            vortex_bail!(
                "Patch position {} is out of bounds for array of length {}",
                position,
                len
            );
        };
        *slot = *value;
    }
    self.buffer = Buffer::from_vec::<T>(own_values);
    Ok(self)
}

/// Views this array as a `PrimitiveTrait<T>` trait object.
///
/// # Panics
///
/// Panics if the array's ptype is not `T::PTYPE`.
pub(crate) fn as_trait<T: NativePType>(&self) -> &dyn PrimitiveTrait<T> {
assert_eq!(self.ptype, T::PTYPE);
self
}

/// Returns an array of type `ptype` built from a clone of this array's buffer and its
/// validity, without converting the stored values.
///
/// When `ptype` already equals the array's ptype, a plain clone of `self` is returned.
///
/// # Panics
///
/// Panics if the byte width of `ptype` differs from the byte width of the current ptype.
pub fn reinterpret_cast(&self, ptype: PType) -> Self {
    let current = self.ptype();
    if current != ptype {
        assert_eq!(
            current.byte_width(),
            ptype.byte_width(),
            "can't reinterpret cast between integers of two different widths"
        );
        PrimitiveArray::new(ptype, self.buffer().clone(), self.validity())
    } else {
        self.clone()
    }
}
}

impl Array for PrimitiveArray {
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/ptype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ macro_rules! match_each_integer_ptype {
PType::U16 => __with__! { u16 },
PType::U32 => __with__! { u32 },
PType::U64 => __with__! { u64 },
_ => panic!("Unsupported ptype {:?}", $self),
_ => panic!("Unsupported ptype {}", $self),
}
})
}
Expand Down
7 changes: 3 additions & 4 deletions vortex-dict/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ impl ScalarAtFn for DictArray {
impl TakeFn for DictArray {
fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
let codes = take(self.codes(), indices)?;
// TODO(ngates): Add function to remove unused entries from dictionary
Ok(DictArray::new(codes, self.values().clone()).to_array())
Ok(DictArray::new(codes, self.values().clone()).into_array())
}
}

Expand All @@ -59,7 +58,7 @@ mod test {
let reference =
PrimitiveArray::from_iter(vec![Some(42), Some(-9), None, Some(42), None, Some(-9)]);
let (codes, values) = dict_encode_typed_primitive::<i32>(&reference);
let dict = DictArray::new(codes.to_array(), values.to_array());
let dict = DictArray::new(codes.into_array(), values.into_array());
let flattened_dict = flatten_primitive(&dict).unwrap();
assert_eq!(flattened_dict.buffer(), reference.buffer());
}
Expand All @@ -71,7 +70,7 @@ mod test {
DType::Utf8(Nullability::Nullable),
);
let (codes, values) = dict_encode_varbin(&reference);
let dict = DictArray::new(codes.to_array(), values.to_array());
let dict = DictArray::new(codes.into_array(), values.into_array());
let flattened_dict = flatten_varbin(&dict).unwrap();
assert_eq!(
flattened_dict.offsets().as_primitive().buffer(),
Expand Down
37 changes: 23 additions & 14 deletions vortex-fastlanes/src/bitpacking/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ use arrayref::array_ref;
use fastlanez::TryBitPack;
use vortex::array::downcast::DowncastArrayBuiltin;
use vortex::array::primitive::PrimitiveArray;
use vortex::array::sparse::SparseArray;
use vortex::array::sparse::{SparseArray, SparseEncoding};
use vortex::array::IntoArray;
use vortex::array::{Array, ArrayRef};
use vortex::compress::{CompressConfig, CompressCtx, EncodingCompression};
use vortex::compute::cast::cast;
use vortex::compute::flatten::flatten_primitive;
use vortex::compute::patch::patch;
use vortex::match_each_integer_ptype;
use vortex::ptype::PType::{I16, I32, I64, I8, U16, U32, U64, U8};
use vortex::ptype::{NativePType, PType};
Expand Down Expand Up @@ -103,13 +102,12 @@ impl EncodingCompression for BitPackedEncoding {
fn bitpack(parray: &PrimitiveArray, bit_width: usize) -> ArrayRef {
// We know the min is > 0, so it's safe to re-interpret signed integers as unsigned.
// TODO(ngates): we should implement this using a vortex cast to centralize this hack.
use PType::*;
let bytes = match parray.ptype() {
I8 | U8 => bitpack_primitive(parray.buffer().typed_data::<u8>(), bit_width),
I16 | U16 => bitpack_primitive(parray.buffer().typed_data::<u16>(), bit_width),
I32 | U32 => bitpack_primitive(parray.buffer().typed_data::<u32>(), bit_width),
I64 | U64 => bitpack_primitive(parray.buffer().typed_data::<u64>(), bit_width),
_ => panic!("Unsupported ptype {:?}", parray.ptype()),
_ => panic!("Unsupported ptype {}", parray.ptype()),
};
PrimitiveArray::from(bytes).into_array()
}
Expand Down Expand Up @@ -165,7 +163,7 @@ fn bitpack_patches(
pub fn unpack(array: &BitPackedArray) -> VortexResult<PrimitiveArray> {
let bit_width = array.bit_width();
let length = array.len();
let encoded = flatten_primitive(cast(array.encoded(), PType::U8.into())?.as_ref())?;
let encoded = flatten_primitive(cast(array.encoded(), U8.into())?.as_ref())?;
let ptype: PType = array.dtype().try_into()?;

let mut unpacked = match ptype {
Expand All @@ -185,21 +183,32 @@ pub fn unpack(array: &BitPackedArray) -> VortexResult<PrimitiveArray> {
unpack_primitive::<u64>(encoded.typed_data::<u8>(), bit_width, length),
array.validity(),
),
_ => panic!("Unsupported ptype {:?}", ptype),
}
.into_array();
_ => panic!("Unsupported ptype {}", ptype),
};

// Cast to signed if necessary
// TODO(ngates): do this more efficiently since we know it's a safe cast. unchecked_cast maybe?
if ptype.is_signed_int() {
unpacked = cast(&unpacked, &ptype.into())?
unpacked = unpacked.reinterpret_cast(ptype);
}

if let Some(patches) = array.patches() {
unpacked = patch(unpacked.as_ref(), patches)?;
patch_unpacked(unpacked, patches)
} else {
Ok(unpacked)
}
}

flatten_primitive(&unpacked)
/// Applies `patches` on top of an already unpacked bit-packed array, consuming `array`.
///
/// Only sparse-encoded patches are handled: the sparse array's resolved indices give the
/// positions to overwrite and its flattened values give the replacements.
///
/// # Panics
///
/// Panics if `patches` uses any encoding other than `SparseEncoding`.
fn patch_unpacked(array: PrimitiveArray, patches: &dyn Array) -> VortexResult<PrimitiveArray> {
    match patches.encoding().id() {
        SparseEncoding::ID => {
            match_each_integer_ptype!(array.ptype(), |$T| {
                let sparse = patches.as_sparse();
                let positions = sparse.resolved_indices();
                let replacements = flatten_primitive(sparse.values())?;
                array.patch(&positions, replacements.typed_data::<$T>())
            })
        }
        _ => panic!("can't patch bitpacked array with {}", patches),
    }
}

pub fn unpack_primitive<T: NativePType + TryBitPack>(
Expand Down Expand Up @@ -244,7 +253,7 @@ pub fn unpack_primitive<T: NativePType + TryBitPack>(

pub(crate) fn unpack_single(array: &BitPackedArray, index: usize) -> VortexResult<Scalar> {
let bit_width = array.bit_width();
let encoded = flatten_primitive(cast(array.encoded(), PType::U8.into())?.as_ref())?;
let encoded = flatten_primitive(cast(array.encoded(), U8.into())?.as_ref())?;
let ptype: PType = array.dtype().try_into()?;

let scalar: Scalar = unsafe {
Expand All @@ -263,7 +272,7 @@ pub(crate) fn unpack_single(array: &BitPackedArray, index: usize) -> VortexResul
unpack_single_primitive::<u64>(encoded.typed_data::<u8>(), bit_width, index)
.map(|v| v.into())
}
_ => vortex_bail!("Unsupported ptype {:?}", ptype),
_ => vortex_bail!("Unsupported ptype {}", ptype),
}?
};

Expand Down
2 changes: 2 additions & 0 deletions vortex-ree/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ impl EncodingCompression for REEEncoding {
compressed_ends,
compressed_values,
ctx.compress_validity(primitive_array.validity())?,
array.len(),
)
.into_array())
}
Expand Down Expand Up @@ -194,6 +195,7 @@ mod test {
vec![2u32, 5, 10].into_array(),
vec![1i32, 2, 3].into_array(),
Some(validity),
10,
);

let decoded = ree_decode(
Expand Down
25 changes: 18 additions & 7 deletions vortex-ree/src/ree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::{Arc, RwLock};
use vortex::array::validity::Validity;
use vortex::array::{check_slice_bounds, Array, ArrayKind, ArrayRef};
use vortex::compress::EncodingCompression;
use vortex::compute::scalar_at::scalar_at;
use vortex::compute::search_sorted::SearchSortedSide;
use vortex::compute::ArrayCompute;
use vortex::encoding::{Encoding, EncodingId, EncodingRef};
Expand All @@ -27,16 +26,21 @@ pub struct REEArray {
}

impl REEArray {
pub fn new(ends: ArrayRef, values: ArrayRef, validity: Option<Validity>) -> Self {
Self::try_new(ends, values, validity).unwrap()
pub fn new(
ends: ArrayRef,
values: ArrayRef,
validity: Option<Validity>,
length: usize,
) -> Self {
Self::try_new(ends, values, validity, length).unwrap()
}

pub fn try_new(
ends: ArrayRef,
values: ArrayRef,
validity: Option<Validity>,
length: usize,
) -> VortexResult<Self> {
let length: usize = scalar_at(ends.as_ref(), ends.len() - 1)?.try_into()?;
if let Some(v) = &validity {
assert_eq!(v.len(), length);
}
Expand All @@ -49,6 +53,7 @@ impl REEArray {
vortex_bail!("Ends array must be strictly sorted",);
}

// TODO(ngates): https://github.com/fulcrum-so/spiral/issues/873
Ok(Self {
ends,
values,
Expand All @@ -71,10 +76,13 @@ impl REEArray {
match ArrayKind::from(array) {
ArrayKind::Primitive(p) => {
let (ends, values) = ree_encode(p);
Ok(
REEArray::new(ends.into_array(), values.into_array(), p.validity())
.into_array(),
Ok(REEArray::new(
ends.into_array(),
values.into_array(),
p.validity(),
p.len(),
)
.into_array())
}
_ => Err(vortex_err!("REE can only encode primitive arrays")),
}
Expand Down Expand Up @@ -212,6 +220,7 @@ mod test {
vec![2u32, 5, 10].into_array(),
vec![1i32, 2, 3].into_array(),
None,
10,
);
assert_eq!(arr.len(), 10);
assert_eq!(
Expand All @@ -234,6 +243,7 @@ mod test {
vec![2u32, 5, 10].into_array(),
vec![1i32, 2, 3].into_array(),
None,
10,
)
.slice(3, 8)
.unwrap();
Expand All @@ -255,6 +265,7 @@ mod test {
vec![2u32, 5, 10].into_array(),
vec![1i32, 2, 3].into_array(),
None,
10,
);
assert_eq!(
flatten_primitive(&arr).unwrap().typed_data::<i32>(),
Expand Down
5 changes: 4 additions & 1 deletion vortex-ree/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{REEArray, REEEncoding};

impl ArraySerde for REEArray {
fn write(&self, ctx: &mut WriteCtx) -> VortexResult<()> {
ctx.write_usize(self.len())?;
ctx.write_validity(self.validity())?;
// TODO(robert): Stop writing this
ctx.dtype(self.ends().dtype())?;
Expand All @@ -20,11 +21,12 @@ impl ArraySerde for REEArray {

impl EncodingSerde for REEEncoding {
fn read(&self, ctx: &mut ReadCtx) -> VortexResult<ArrayRef> {
let len = ctx.read_usize()?;
let validity = ctx.read_validity()?;
let ends_dtype = ctx.dtype()?;
let ends = ctx.with_schema(&ends_dtype).read()?;
let values = ctx.read()?;
Ok(REEArray::new(ends, values, validity).into_array())
Ok(REEArray::new(ends, values, validity, len).into_array())
}
}

Expand Down Expand Up @@ -55,6 +57,7 @@ mod test {
vec![0u8, 9, 20, 32, 49].into_array(),
vec![-7i64, -13, 17, 23].into_array(),
None,
49,
);
let read_arr = roundtrip_array(&arr).unwrap();
let read_ree = read_arr.as_ree();
Expand Down
Loading

0 comments on commit 88d5fd0

Please sign in to comment.