Skip to content

Commit

Permalink
fix: Dict LikeFn length mismatch (#2043)
Browse files Browse the repository at this point in the history
Before this PR, `LikeFn` took a `pattern: &ArrayData` argument, but the
`DictArray` implementation would only (accidentally) return a
correct result if `pattern` was a `ConstantArray`.
  • Loading branch information
lwwmanning authored Jan 21, 2025
1 parent 047228b commit 5486c92
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
14 changes: 11 additions & 3 deletions encodings/dict/src/compute/like.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use vortex_array::array::ConstantArray;
use vortex_array::compute::{like, LikeFn, LikeOptions};
use vortex_array::{ArrayData, IntoArrayData};
use vortex_error::VortexResult;
Expand All @@ -10,8 +11,15 @@ impl LikeFn<DictArray> for DictEncoding {
array: DictArray,
pattern: &ArrayData,
options: LikeOptions,
) -> VortexResult<ArrayData> {
let values = like(array.values(), pattern, options)?;
Ok(DictArray::try_new(array.codes(), values)?.into_array())
) -> VortexResult<Option<ArrayData>> {
if let Some(pattern) = pattern.as_constant() {
let pattern = ConstantArray::new(pattern, array.values().len()).into_array();
let values = like(array.values(), &pattern, options)?;
Ok(Some(
DictArray::try_new(array.codes(), values)?.into_array(),
))
} else {
Ok(None)
}
}
}
42 changes: 30 additions & 12 deletions vortex-array/src/compute/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub trait LikeFn<Array> {
array: Array,
pattern: &ArrayData,
options: LikeOptions,
) -> VortexResult<ArrayData>;
) -> VortexResult<Option<ArrayData>>;
}

impl<E: Encoding> LikeFn<ArrayData> for E
Expand All @@ -24,7 +24,7 @@ where
array: ArrayData,
pattern: &ArrayData,
options: LikeOptions,
) -> VortexResult<ArrayData> {
) -> VortexResult<Option<ArrayData>> {
let encoding = array
.encoding()
.as_any()
Expand Down Expand Up @@ -58,20 +58,32 @@ pub fn like(
if !matches!(pattern.dtype(), DType::Utf8(..)) {
vortex_bail!("Expected utf8 pattern, got {}", array.dtype());
}
if array.len() != pattern.len() {
vortex_bail!(
"Length mismatch lhs len {} ({}) != rhs len {} ({})",
array.len(),
array.encoding().id(),
pattern.len(),
pattern.encoding().id()
);
}

let expected_dtype =
DType::Bool((array.dtype().is_nullable() || pattern.dtype().is_nullable()).into());
let array_encoding = array.encoding().id();

let result = if let Some(f) = array.encoding().like_fn() {
f.like(array, pattern, options)
} else {
// Otherwise, we canonicalize into a UTF8 array.
log::debug!(
"No like implementation found for encoding {}",
array.encoding().id(),
);
arrow_like(array, pattern, options)
}?;
let result = array
.encoding()
.like_fn()
.and_then(|f| f.like(array.clone(), pattern, options).transpose())
.unwrap_or_else(|| {
// Otherwise, we canonicalize into a UTF8 array.
log::debug!(
"No like implementation found for encoding {}",
array.encoding().id(),
);
arrow_like(array, pattern, options)
})?;

debug_assert_eq!(
result.len(),
Expand All @@ -97,6 +109,12 @@ pub(crate) fn arrow_like(
) -> VortexResult<ArrayData> {
let nullable = array.dtype().is_nullable();
let len = array.len();
debug_assert_eq!(
array.len(),
pattern.len(),
"Arrow Like: length mismatch for {}",
array.encoding().id()
);
let lhs = unsafe { Datum::try_new(array)? };
let rhs = unsafe { Datum::try_new(pattern.clone())? };

Expand Down

0 comments on commit 5486c92

Please sign in to comment.