Skip to content

Commit

Permalink
fix: support fp16 type in SQ (#3417)
Browse files Browse the repository at this point in the history
  • Loading branch information
chebbyChefNEQ authored Jan 26, 2025
1 parent 5a92d31 commit 58c5e27
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
21 changes: 17 additions & 4 deletions rust/lance-index/src/vector/sq/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

use std::ops::Range;

use arrow::compute::concat_batches;
use arrow::{compute::concat_batches, datatypes::Float16Type};
use arrow_array::{
cast::AsArray,
types::{Float32Type, UInt64Type, UInt8Type},
ArrayRef, RecordBatch, UInt64Array, UInt8Array,
};
use arrow_schema::SchemaRef;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
use deepsize::DeepSizeOf;
use lance_core::{Error, Result, ROW_ID};
Expand Down Expand Up @@ -391,8 +391,21 @@ pub struct SQDistCalculator<'a> {

impl<'a> SQDistCalculator<'a> {
fn new(query: ArrayRef, storage: &'a ScalarQuantizationStorage, bounds: Range<f64>) -> Self {
let query_sq_code =
scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), &bounds);
// Hand-rolled dynamic dispatch on the query's data type is acceptable here:
// since we search 10s-100s of partitions, we can afford the overhead.
// This could be annoying at indexing time for HNSW, which requires constructing the
// dist calculator frequently. However, HNSW isn't a first-class citizen in Lance yet, so be it.
let query_sq_code = match query.data_type() {
DataType::Float16 => {
scale_to_u8::<Float16Type>(query.as_primitive::<Float16Type>().values(), &bounds)
}
DataType::Float32 => {
scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), &bounds)
}
_ => {
panic!("Unsupported data type for ScalarQuantizationStorage");
}
};
Self {
query_sq_code,
bounds,
Expand Down
29 changes: 16 additions & 13 deletions rust/lance/src/io/exec/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,22 +495,25 @@ impl ExecutionPlan for ANNIvfSubIndexExec {
self: Arc<Self>,
mut children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
let plan = if children.len() == 1 || children.len() == 2 {
if children.len() == 2 {
let _prefilter = children.pop().expect("length checked");
}
// NOTE!!!! Prefilter transformation is ignored.
Self {
input: children.pop().expect("length checked"),
dataset: self.dataset.clone(),
indices: self.indices.clone(),
query: self.query.clone(),
prefilter_source: self.prefilter_source.clone(),
properties: self.properties.clone(),
}
} else {
return Err(DataFusionError::Internal(
"ANNSubIndexExec node must have exactly one child".to_string(),
"ANNSubIndexExec node must have exactly one or two (prefilter) child".to_string(),
));
}

let new_plan = Self {
input: children.pop().expect("length checked"),
dataset: self.dataset.clone(),
indices: self.indices.clone(),
query: self.query.clone(),
prefilter_source: self.prefilter_source.clone(),
properties: self.properties.clone(),
};

Ok(Arc::new(new_plan))
Ok(Arc::new(plan))
}

fn execute(
Expand Down

0 comments on commit 58c5e27

Please sign in to comment.