Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Jul 21, 2024
1 parent cb23d03 commit 24ffae7
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 16 deletions.
1 change: 0 additions & 1 deletion arrow-arith/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-ord = { workspace = true }
chrono = { workspace = true }
half = { version = "2.1", default-features = false }
num = { version = "0.4", default-features = false, features = ["std"] }
Expand Down
27 changes: 15 additions & 12 deletions arrow-arith/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use arrow_array::iterator::ArrayIter;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_ord::cmp::compare_byte_view_unchecked;
use arrow_schema::*;
use std::borrow::BorrowMut;
use std::cmp::{self, Ordering};
Expand Down Expand Up @@ -415,6 +414,8 @@ where

/// Helper to compute min/max of [`GenericByteViewArray<T>`].
/// The specialized min/max leverages the inlined values to compare the byte views.
/// `swap_cond` is the condition to swap current min/max with the new value.
/// For example, `Ordering::Greater` for max and `Ordering::Less` for min.
fn min_max_view_helper<T: ByteViewType>(
array: &GenericByteViewArray<T>,
swap_cond: cmp::Ordering,
Expand All @@ -423,30 +424,32 @@ fn min_max_view_helper<T: ByteViewType>(
if null_count == array.len() {
None
} else if null_count == 0 {
let min_idx = (0..array.len()).reduce(|acc, item| {
let target_idx = (0..array.len()).reduce(|acc, item| {
// SAFETY: array's length is correct so item is within bounds
let cmp = unsafe { compare_byte_view_unchecked(array, acc, array, item) };
let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, item, array, acc) };
if cmp == swap_cond {
item
} else {
acc
}
});
// Safety: idx came from valid range `0..array.len()`
unsafe { min_idx.map(|idx| array.value_unchecked(idx)) }
// SAFETY: idx came from valid range `0..array.len()`
unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
} else {
let nulls = array.nulls().unwrap();

let min_idx = nulls.valid_indices().reduce(|acc_idx, idx| {
let cmp = unsafe { compare_byte_view_unchecked(array, acc_idx, array, idx) };
let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| {
let cmp =
unsafe { GenericByteViewArray::compare_unchecked(array, idx, array, acc_idx) };
if cmp == swap_cond {
idx
} else {
acc_idx
}
});

unsafe { min_idx.map(|idx| array.value_unchecked(idx)) }
// SAFETY: idx came from valid range `0..array.len()`
unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
}
}

Expand All @@ -457,7 +460,7 @@ pub fn max_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) -> Option<&

/// Returns the maximum value in the binary view array, according to the natural order.
pub fn max_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
min_max_view_helper(array, Ordering::Less)
min_max_view_helper(array, Ordering::Greater)
}

/// Returns the minimum value in the binary array, according to the natural order.
Expand All @@ -467,7 +470,7 @@ pub fn min_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) -> Option<&

/// Returns the minimum value in the binary view array, according to the natural order.
pub fn min_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
min_max_view_helper(array, Ordering::Greater)
min_max_view_helper(array, Ordering::Less)
}

/// Returns the maximum value in the string array, according to the natural order.
Expand All @@ -477,7 +480,7 @@ pub fn max_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&

/// Returns the maximum value in the string view array, according to the natural order.
pub fn max_string_view(array: &StringViewArray) -> Option<&str> {
min_max_view_helper(array, Ordering::Less)
min_max_view_helper(array, Ordering::Greater)
}

/// Returns the minimum value in the string array, according to the natural order.
Expand All @@ -487,7 +490,7 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&

/// Returns the minimum value in the string view array, according to the natural order.
pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
min_max_view_helper(array, Ordering::Greater)
min_max_view_helper(array, Ordering::Less)
}

/// Returns the sum of values in the array.
Expand Down
60 changes: 60 additions & 0 deletions arrow-array/src/array/byte_view_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,66 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {

builder.finish()
}

/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
///
/// Comparing two ByteView types are non-trivial.
/// It takes a bit of patience to understand why we don't just compare two &[u8] directly.
///
/// ByteView types give us the following two advantages, and we need to be careful not to lose them:
/// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view.
/// Meaning that reading one array element requires only one memory access
/// (two memory access required for StringArray, one for offset buffer, the other for value buffer).
///
/// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray,
/// thanks to the inlined 4 bytes.
/// Consider equality check:
/// If the first four bytes of the two strings are different, we can return false immediately (with just one memory access).
///
/// If we directly compare two &[u8], we materialize the entire string (i.e., make multiple memory accesses), which might be unnecessary.
/// - Most of the time (eq, ord), we only need to look at the first 4 bytes to know the answer,
/// e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string.
///
/// # Order check flow
/// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view.
/// (2) if any of the string is larger than 12 bytes, we need to compare the full string.
/// (2.1) if the inlined 4 bytes are different, we can return the result immediately.
/// (2.2) o.w., we need to compare the full string.
///
/// # Safety
/// The left/right_idx must within range of each array
pub unsafe fn compare_unchecked(
left: &GenericByteViewArray<T>,
left_idx: usize,
right: &GenericByteViewArray<T>,
right_idx: usize,
) -> std::cmp::Ordering {
let l_view = left.views().get_unchecked(left_idx);
let l_len = *l_view as u32;

let r_view = right.views().get_unchecked(right_idx);
let r_len = *r_view as u32;

if l_len <= 12 && r_len <= 12 {
let l_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
let r_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
return l_data.cmp(r_data);
}

// one of the string is larger than 12 bytes,
// we then try to compare the inlined data first
let l_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, 4) };
let r_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, 4) };
if r_inlined_data != l_inlined_data {
return l_inlined_data.cmp(r_inlined_data);
}

// unfortunately, we need to compare the full data
let l_full_data: &[u8] = unsafe { left.value_unchecked(left_idx).as_ref() };
let r_full_data: &[u8] = unsafe { right.value_unchecked(right_idx).as_ref() };

l_full_data.cmp(r_full_data)
}
}

impl<T: ByteViewType + ?Sized> Debug for GenericByteViewArray<T> {
Expand Down
7 changes: 4 additions & 3 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,13 @@ impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray<T> {
return false;
}

unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_eq() }
unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_eq() }
}

fn is_lt(l: Self::Item, r: Self::Item) -> bool {
// # Safety
// The index is within bounds as it is checked in value()
unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_lt() }
unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_lt() }
}

fn len(&self) -> usize {
Expand Down Expand Up @@ -626,7 +626,7 @@ pub fn compare_byte_view<T: ByteViewType>(
) -> std::cmp::Ordering {
assert!(left_idx < left.len());
assert!(right_idx < right.len());
unsafe { compare_byte_view_unchecked(left, left_idx, right, right_idx) }
unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, right_idx) }
}

/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
Expand Down Expand Up @@ -656,6 +656,7 @@ pub fn compare_byte_view<T: ByteViewType>(
///
/// # Safety
/// The left/right_idx must within range of each array
#[deprecated(note = "Use `GenericByteViewArray::compare_unchecked` instead")]
pub unsafe fn compare_byte_view_unchecked<T: ByteViewType>(
left: &GenericByteViewArray<T>,
left_idx: usize,
Expand Down

0 comments on commit 24ffae7

Please sign in to comment.