diff --git a/ffi/examples/read-table/arrow.c b/ffi/examples/read-table/arrow.c index d58a2fa2d..4f660038e 100644 --- a/ffi/examples/read-table/arrow.c +++ b/ffi/examples/read-table/arrow.c @@ -11,6 +11,7 @@ ArrowContext* init_arrow_context() context->num_batches = 0; context->batches = NULL; context->cur_filter = NULL; + context->cur_transform = NULL; return context; } @@ -50,86 +51,10 @@ static GArrowRecordBatch* get_record_batch(FFI_ArrowArray* array, GArrowSchema* return record_batch; } -// Add columns to a record batch for each partition. In a "real" engine we would want to parse the -// string values into the correct data type. This program just adds all partition columns as strings -// for simplicity -static GArrowRecordBatch* add_partition_columns( - GArrowRecordBatch* record_batch, - PartitionList* partition_cols, - const CStringMap* partition_values) -{ - gint64 rows = garrow_record_batch_get_n_rows(record_batch); - gint64 cols = garrow_record_batch_get_n_columns(record_batch); - GArrowRecordBatch* cur_record_batch = record_batch; - GError* error = NULL; - for (uintptr_t i = 0; i < partition_cols->len; i++) { - char* col = partition_cols->cols[i]; - guint pos = cols + i; - KernelStringSlice key = { col, strlen(col) }; - char* partition_val = get_from_map(partition_values, key, allocate_string); - print_diag( - " Adding partition column '%s' with value '%s' at column %u\n", - col, - partition_val ? partition_val : "NULL", - pos); - GArrowStringArrayBuilder* builder = garrow_string_array_builder_new(); - for (gint64 i = 0; i < rows; i++) { - if (partition_val) { - garrow_string_array_builder_append_string(builder, partition_val, &error); - } else { - garrow_array_builder_append_null((GArrowArrayBuilder*)builder, &error); - } - if (report_g_error("Can't append to partition column builder", error)) { - break; - } - } - - if (partition_val) { - free(partition_val); - } - - if (error != NULL) { - printf("Giving up on column %s\n", col); - g_error_free(error); - g_object_unref(builder); - error = NULL; - continue; - } - - GArrowArray* partition_col = garrow_array_builder_finish((GArrowArrayBuilder*)builder, &error); - if (report_g_error("Can't build string array for parition column", error)) { - printf("Giving up on column %s\n", col); - g_error_free(error); - g_object_unref(builder); - error = NULL; - continue; - } - g_object_unref(builder); - - GArrowDataType* string_data_type = (GArrowDataType*)garrow_string_data_type_new(); - GArrowField* field = garrow_field_new(col, string_data_type); - GArrowRecordBatch* old_batch = cur_record_batch; - cur_record_batch = garrow_record_batch_add_column(old_batch, pos, field, partition_col, &error); - g_object_unref(old_batch); - g_object_unref(partition_col); - g_object_unref(string_data_type); - g_object_unref(field); - if (cur_record_batch == NULL) { - if (error != NULL) { - printf("Could not add column at %u: %s\n", pos, error->message); - g_error_free(error); - } - } - } - return cur_record_batch; -} - // append a batch to our context static void add_batch_to_context( ArrowContext* context, - ArrowFFIData* arrow_data, - PartitionList* partition_cols, - const CStringMap* partition_values) + ArrowFFIData* arrow_data) { GArrowSchema* schema = get_schema(&arrow_data->schema); GArrowRecordBatch* record_batch = get_record_batch(&arrow_data->array, schema); @@ -142,11 +67,6 @@ static void add_batch_to_context( g_object_unref(context->cur_filter); context->cur_filter = NULL; } - record_batch = add_partition_columns(record_batch, partition_cols, 
partition_values); - if (record_batch == NULL) { - printf("Failed to add parition columns, not adding batch\n"); - return; - } context->batches = g_list_append(context->batches, record_batch); context->num_batches++; print_diag( @@ -187,20 +107,47 @@ static GArrowBooleanArray* slice_to_arrow_bool_array(const KernelBoolSlice slice return (GArrowBooleanArray*)ret; } +static ExclusiveEngineData* apply_transform( + struct EngineContext* context, + ExclusiveEngineData* data) { + print_diag(" Applying transform\n"); + SharedExpressionEvaluator* evaluator = get_evaluator( + context->engine, + context->read_schema, // input schema + context->arrow_context->cur_transform, + context->logical_schema); // output schema + ExternResultHandleExclusiveEngineData transformed_res = evaluate( + context->engine, + &data, + evaluator); + if (transformed_res.tag != OkHandleExclusiveEngineData) { + print_error("Failed to transform read data.", (Error*)transformed_res.err); + free_error((Error*)transformed_res.err); + return NULL; + } + free_engine_data(data); + free_evaluator(evaluator); + return transformed_res.ok; +} + // This is the callback that will be called for each chunk of data read from the parquet file static void visit_read_data(void* vcontext, ExclusiveEngineData* data) { print_diag(" Converting read data to arrow\n"); struct EngineContext* context = vcontext; - ExternResultArrowFFIData arrow_res = get_raw_arrow_data(data, context->engine); + ExclusiveEngineData* transformed = apply_transform(context, data); + if (!transformed) { + // TODO: What? + exit(-1); + } + ExternResultArrowFFIData arrow_res = get_raw_arrow_data(transformed, context->engine); if (arrow_res.tag != OkArrowFFIData) { print_error("Failed to get arrow data.", (Error*)arrow_res.err); free_error((Error*)arrow_res.err); exit(-1); } ArrowFFIData* arrow_data = arrow_res.ok; - add_batch_to_context( - context->arrow_context, arrow_data, context->partition_cols, context->partition_values); + add_batch_to_context(context->arrow_context, arrow_data); free(arrow_data); // just frees the struct, the data and schema are freed/owned by add_batch_to_context } @@ -208,7 +155,8 @@ static void visit_read_data(void* vcontext, ExclusiveEngineData* data) void c_read_parquet_file( struct EngineContext* context, const KernelStringSlice path, - const KernelBoolSlice selection_vector) + const KernelBoolSlice selection_vector, + const Expression* transform) { int full_len = strlen(context->table_root) + path.len + 1; char* full_path = malloc(sizeof(char) * full_len); @@ -233,6 +181,7 @@ void c_read_parquet_file( } context->arrow_context->cur_filter = sel_array; } + context->arrow_context->cur_transform = transform; ExclusiveFileReadResultIterator* read_iter = read_res.ok; for (;;) { ExternResultbool ok_res = read_result_next(read_iter, context, visit_read_data); diff --git a/ffi/examples/read-table/arrow.h b/ffi/examples/read-table/arrow.h index 0236b238b..8f34cdd4f 100644 --- a/ffi/examples/read-table/arrow.h +++ b/ffi/examples/read-table/arrow.h @@ -15,13 +15,15 @@ typedef struct ArrowContext gsize num_batches; GList* batches; GArrowBooleanArray* cur_filter; + const Expression* cur_transform; } ArrowContext; ArrowContext* init_arrow_context(void); void c_read_parquet_file( struct EngineContext* context, const KernelStringSlice path, - const KernelBoolSlice selection_vector); + const KernelBoolSlice selection_vector, + const Expression* transform); void print_arrow_context(ArrowContext* context); void free_arrow_context(ArrowContext* context); diff --git 
a/ffi/examples/read-table/read_table.c b/ffi/examples/read-table/read_table.c index 0aa8caa41..5be6a3e4e 100644 --- a/ffi/examples/read-table/read_table.c +++ b/ffi/examples/read-table/read_table.c @@ -28,7 +28,7 @@ void print_partition_info(struct EngineContext* context, const CStringMap* parti for (uintptr_t i = 0; i < context->partition_cols->len; i++) { char* col = context->partition_cols->cols[i]; KernelStringSlice key = { col, strlen(col) }; - char* partition_val = get_from_map(partition_values, key, allocate_string); + char* partition_val = get_from_string_map(partition_values, key, allocate_string); if (partition_val) { print_diag(" partition '%s' here: %s\n", col, partition_val); free(partition_val); @@ -50,6 +50,7 @@ void scan_row_callback( int64_t size, const Stats* stats, const DvInfo* dv_info, + const Expression* transform, const CStringMap* partition_values) { (void)size; // not using this at the moment @@ -76,7 +77,7 @@ void scan_row_callback( context->partition_values = partition_values; print_partition_info(context, partition_values); #ifdef PRINT_ARROW_DATA - c_read_parquet_file(context, path, selection_vector); + c_read_parquet_file(context, path, selection_vector, transform); #endif free_bool_slice(selection_vector); context->partition_values = NULL; @@ -87,14 +88,15 @@ void scan_row_callback( void do_visit_scan_data( void* engine_context, ExclusiveEngineData* engine_data, - KernelBoolSlice selection_vec) + KernelBoolSlice selection_vec, + const CTransformMap* transforms) { print_diag("\nScan iterator found some data to read\n Of this data, here is " "a selection vector\n"); print_selection_vector(" ", &selection_vec); // Ask kernel to iterate each individual file and call us back with extracted metadata print_diag("Asking kernel to call us back for each scan row (file to read)\n"); - visit_scan_data(engine_data, selection_vec, engine_context, scan_row_callback); + visit_scan_data(engine_data, selection_vec, transforms, engine_context, scan_row_callback); free_bool_slice(selection_vec); free_engine_data(engine_data); } @@ -272,10 +274,12 @@ int main(int argc, char* argv[]) SharedScan* scan = scan_res.ok; SharedGlobalScanState* global_state = get_global_scan_state(scan); + SharedSchema* logical_schema = get_global_logical_schema(global_state); SharedSchema* read_schema = get_global_read_schema(global_state); PartitionList* partition_cols = get_partition_list(global_state); struct EngineContext context = { global_state, + logical_schema, read_schema, table_root, engine, @@ -320,7 +324,8 @@ int main(int argc, char* argv[]) free_kernel_scan_data(data_iter); free_scan(scan); - free_global_read_schema(read_schema); + free_schema(logical_schema); + free_schema(read_schema); free_global_scan_state(global_state); free_snapshot(snapshot); free_engine(engine); diff --git a/ffi/examples/read-table/read_table.h b/ffi/examples/read-table/read_table.h index 28d9c72dc..cf55863d9 100644 --- a/ffi/examples/read-table/read_table.h +++ b/ffi/examples/read-table/read_table.h @@ -14,6 +14,7 @@ typedef struct PartitionList struct EngineContext { SharedGlobalScanState* global_state; + SharedSchema* logical_schema; SharedSchema* read_schema; char* table_root; SharedExternEngine* engine; diff --git a/ffi/src/engine_funcs.rs b/ffi/src/engine_funcs.rs index f8534dfc0..a2dd4a014 100644 --- a/ffi/src/engine_funcs.rs +++ b/ffi/src/engine_funcs.rs @@ -2,7 +2,10 @@ use std::sync::Arc; -use delta_kernel::{schema::Schema, DeltaResult, FileDataReadResultIterator}; +use delta_kernel::{ + schema::{DataType, 
Schema, SchemaRef}, + DeltaResult, EngineData, Expression, ExpressionEvaluator, FileDataReadResultIterator, +}; use delta_kernel_ffi_macros::handle_descriptor; use tracing::debug; use url::Url; @@ -97,7 +100,7 @@ pub unsafe extern "C" fn free_read_result_iter(data: Handle, + engine: Handle, // TODO Does this cause a free? file: &FileMeta, physical_schema: Handle, ) -> ExternResult> { @@ -130,3 +133,101 @@ fn read_parquet_file_impl( }); Ok(res.into()) } + +// Expression Eval + +#[handle_descriptor(target=dyn ExpressionEvaluator, mutable=false)] +pub struct SharedExpressionEvaluator; + +#[no_mangle] +pub unsafe extern "C" fn get_evaluator( + engine: Handle, + input_schema: Handle, + expression: &Expression, + // TODO: Make this a data_type, and give a way for c code to go between schema <-> datatype + output_type: Handle, +) -> Handle { + let engine = unsafe { engine.clone_as_arc() }; + let input_schema = unsafe { input_schema.clone_as_arc() }; + let output_type: DataType = output_type.as_ref().clone().into(); + get_evaluator_impl(engine, input_schema, expression, output_type) +} + +fn get_evaluator_impl( + extern_engine: Arc, + input_schema: SchemaRef, + expression: &Expression, + output_type: DataType, +) -> Handle { + let engine = extern_engine.engine(); + let evaluator = engine.get_expression_handler().get_evaluator( + input_schema, + expression.clone(), + output_type, + ); + evaluator.into() +} + +/// Free an evaluator +/// # Safety +/// +/// Caller is responsible for passing a valid handle. +#[no_mangle] +pub unsafe extern "C" fn free_evaluator(evaluator: Handle) { + debug!("engine released evaluator"); + evaluator.drop_handle(); +} + + +#[no_mangle] +pub unsafe extern "C" fn evaluate( + engine: Handle, + batch: &mut Handle, + evaluator: Handle, +) -> ExternResult> { + let engine = unsafe { engine.clone_as_arc() }; + let batch = unsafe { batch.as_mut() }; + let evaluator = unsafe { evaluator.clone_as_arc() }; + let res = evaluate_impl(batch, evaluator.as_ref()); + res.into_extern_result(&engine.as_ref()) +} + +fn evaluate_impl( + batch: &dyn EngineData, + evaluator: &dyn ExpressionEvaluator, +) -> DeltaResult> { + let res = evaluator.evaluate(batch); + res.map(|d| d.into()) +} + +#[cfg(test)] +mod tests { + use super::get_evaluator; + use crate::{free_engine, tests::get_default_engine}; + use delta_kernel::{ + schema::{DataType, StructField, StructType}, + Expression, + }; + use std::sync::Arc; + + #[test] + fn test_get_evaluator() { + let engine = get_default_engine(); + let in_schema = Arc::new(StructType::new(vec![StructField::new( + "a", + DataType::LONG, + true, + )])); + let expr = Expression::literal(1); + let output_type = in_schema.clone(); + unsafe { + get_evaluator( + engine.shallow_copy(), + in_schema.into(), + &expr, + output_type.into(), + ); + free_engine(engine); + } + } +} diff --git a/ffi/src/expressions/kernel.rs b/ffi/src/expressions/kernel.rs index f2ed8b1a3..10856735f 100644 --- a/ffi/src/expressions/kernel.rs +++ b/ffi/src/expressions/kernel.rs @@ -189,6 +189,29 @@ pub struct EngineExpressionVisitor { pub unsafe extern "C" fn visit_expression( expression: &Handle, visitor: &mut EngineExpressionVisitor, +) -> usize { + visit_expression_internal(expression.as_ref(), visitor) +} + +/// Visit the expression of the passed [`Expression`] pointer using the provided `visitor`. See the +/// documentation of [`EngineExpressionVisitor`] for a description of how this visitor works. 
+/// +/// This method returns the id that the engine generated for the top level expression +/// +/// # Safety +/// +/// The caller must pass a valid Expression pointer and expression visitor +#[no_mangle] +pub unsafe extern "C" fn visit_expression_ref( + expression: &Expression, + visitor: &mut EngineExpressionVisitor, +) -> usize { + visit_expression_internal(expression, visitor) +} + +pub fn visit_expression_internal( + expression: &Expression, + visitor: &mut EngineExpressionVisitor, ) -> usize { macro_rules! call { ( $visitor:ident, $visitor_fn:ident $(, $extra_args:expr) *) => { @@ -367,6 +390,6 @@ pub unsafe extern "C" fn visit_expression( } } let top_level = call!(visitor, make_field_list, 1); - visit_expression_impl(visitor, expression.as_ref(), top_level); + visit_expression_impl(visitor, expression, top_level); top_level } diff --git a/ffi/src/lib.rs b/ffi/src/lib.rs index 323f02ac9..61c4a3db1 100644 --- a/ffi/src/lib.rs +++ b/ffi/src/lib.rs @@ -330,7 +330,7 @@ pub unsafe extern "C" fn free_row_indexes(slice: KernelRowIndexArray) { /// an opaque struct that encapsulates data read by an engine. this handle can be passed back into /// some kernel calls to operate on the data, or can be converted into the raw data as read by the /// [`delta_kernel::Engine`] by calling [`get_raw_engine_data`] -#[handle_descriptor(target=dyn EngineData, mutable=true, sized=false)] +#[handle_descriptor(target=dyn EngineData, mutable=true)] pub struct ExclusiveEngineData; /// Drop an `ExclusiveEngineData`. @@ -767,7 +767,7 @@ mod tests { } } - fn get_default_engine() -> Handle { + pub(crate) fn get_default_engine() -> Handle { let path = "memory:///doesntmatter/foo"; let path = kernel_string_slice!(path); let builder = unsafe { ok_or_panic(get_engine_builder(path, allocate_err)) }; diff --git a/ffi/src/scan.rs b/ffi/src/scan.rs index d5695c130..c53dc968f 100644 --- a/ffi/src/scan.rs +++ b/ffi/src/scan.rs @@ -7,7 +7,7 @@ use delta_kernel::scan::state::{visit_scan_files, DvInfo, GlobalScanState}; use delta_kernel::scan::{Scan, ScanData}; use delta_kernel::schema::Schema; use delta_kernel::snapshot::Snapshot; -use delta_kernel::{DeltaResult, Error}; +use delta_kernel::{DeltaResult, Error, Expression, ExpressionRef}; use delta_kernel_ffi_macros::handle_descriptor; use tracing::debug; use url::Url; @@ -15,6 +15,7 @@ use url::Url; use crate::expressions::engine::{ unwrap_kernel_expression, EnginePredicate, KernelExpressionVisitorState, }; +use crate::expressions::SharedExpression; use crate::{ kernel_string_slice, AllocateStringFn, ExclusiveEngineData, ExternEngine, ExternResult, IntoExternResult, KernelBoolSlice, KernelRowIndexArray, KernelStringSlice, NullableCvoid, @@ -99,12 +100,26 @@ pub unsafe extern "C" fn get_global_read_schema( state.physical_schema.clone().into() } -/// Free a global read schema +/// Get the kernel view of the physical read schema that an engine should read from parquet file in +/// a scan +/// +/// # Safety +/// Engine is responsible for providing a valid GlobalScanState pointer +#[no_mangle] +pub unsafe extern "C" fn get_global_logical_schema( + state: Handle, +) -> Handle { + let state = unsafe { state.as_ref() }; + state.logical_schema.clone().into() +} + + +/// Free a schema /// /// # Safety /// Engine is responsible for providing a valid schema obtained via [`get_global_read_schema`] #[no_mangle] -pub unsafe extern "C" fn free_global_read_schema(schema: Handle) { +pub unsafe extern "C" fn free_schema(schema: Handle) { schema.drop_handle(); } @@ -211,6 +226,7 @@ pub unsafe 
extern "C" fn kernel_scan_data_next( engine_context: NullableCvoid, engine_data: Handle, selection_vector: KernelBoolSlice, + transforms: &CTransformMap, ), ) -> ExternResult { let data = unsafe { data.as_ref() }; @@ -224,15 +240,17 @@ fn kernel_scan_data_next_impl( engine_context: NullableCvoid, engine_data: Handle, selection_vector: KernelBoolSlice, + transforms: &CTransformMap, ), ) -> DeltaResult { let mut data = data .data .lock() .map_err(|_| Error::generic("poisoned mutex"))?; - if let Some((data, sel_vec)) = data.next().transpose()? { + if let Some((data, sel_vec, transforms)) = data.next().transpose()? { let bool_slice = KernelBoolSlice::from(sel_vec); - (engine_visitor)(engine_context, data.into(), bool_slice); + let transform_map = CTransformMap { transforms }; + (engine_visitor)(engine_context, data.into(), bool_slice, &transform_map); Ok(true) } else { Ok(false) @@ -266,6 +284,7 @@ type CScanCallback = extern "C" fn( size: i64, stats: Option<&Stats>, dv_info: &DvInfo, + transform: Option<&Expression>, partition_map: &CStringMap, ); @@ -281,7 +300,7 @@ pub struct CStringMap { /// # Safety /// /// The engine is responsible for providing a valid [`CStringMap`] pointer and [`KernelStringSlice`] -pub unsafe extern "C" fn get_from_map( +pub unsafe extern "C" fn get_from_string_map( map: &CStringMap, key: KernelStringSlice, allocate_fn: AllocateStringFn, @@ -293,6 +312,30 @@ pub unsafe extern "C" fn get_from_map( .and_then(|v| allocate_fn(kernel_string_slice!(v))) } +pub struct CTransformMap { + transforms: HashMap, +} + + +#[no_mangle] +/// allow probing into a CTransformMap. If the specified row id is in the map, kernel will return a +/// handle to the transform expression for that row. If the row id is not in the map, this will +/// return NULL +/// +/// # Safety +/// +/// The engine is responsible for providing a valid [`CTransformMap`] pointer +pub unsafe extern "C" fn get_from_transform_map( + transform_map: &CTransformMap, + row: usize, +) -> Handle { + if let Some(transform) = transform_map.transforms.get(&row).cloned() { + transform.into() + } else { + panic!("Hrmm"); + } +} + /// Get a selection vector out of a [`DvInfo`] struct /// /// # Safety @@ -355,8 +398,10 @@ fn rust_callback( size: i64, kernel_stats: Option, dv_info: DvInfo, + transform: Option, partition_values: HashMap, ) { + let transform = transform.map(|e| e.as_ref().clone()); let partition_map = CStringMap { values: partition_values, }; @@ -369,6 +414,7 @@ fn rust_callback( size, stats.as_ref(), &dv_info, + transform.as_ref(), &partition_map, ); } @@ -388,6 +434,7 @@ struct ContextWrapper { pub unsafe extern "C" fn visit_scan_data( data: Handle, selection_vec: KernelBoolSlice, + transforms: &CTransformMap, engine_context: NullableCvoid, callback: CScanCallback, ) { @@ -398,5 +445,12 @@ pub unsafe extern "C" fn visit_scan_data( callback, }; // TODO: return ExternResult to caller instead of panicking? 
- visit_scan_files(data, selection_vec, context_wrapper, rust_callback).unwrap(); + visit_scan_files( + data, + selection_vec, + &transforms.transforms, + context_wrapper, + rust_callback, + ) + .unwrap(); } diff --git a/kernel/examples/inspect-table/src/main.rs b/kernel/examples/inspect-table/src/main.rs index ea25a8404..e7c577042 100644 --- a/kernel/examples/inspect-table/src/main.rs +++ b/kernel/examples/inspect-table/src/main.rs @@ -12,7 +12,7 @@ use delta_kernel::expressions::ColumnName; use delta_kernel::scan::state::{DvInfo, Stats}; use delta_kernel::scan::ScanBuilder; use delta_kernel::schema::{ColumnNamesAndTypes, DataType}; -use delta_kernel::{DeltaResult, Error, Table}; +use delta_kernel::{DeltaResult, Error, ExpressionRef, Table}; use std::collections::HashMap; use std::process::ExitCode; @@ -163,6 +163,7 @@ fn print_scan_file( size: i64, stats: Option, dv_info: DvInfo, + transform: Option, partition_values: HashMap, ) { let num_record_str = if let Some(s) = stats { @@ -176,6 +177,7 @@ fn print_scan_file( Size (bytes):\t{size}\n \ Num Records:\t{num_record_str}\n \ Has DV?:\t{}\n \ + Transform:\t{transform:?}\n \ Part Vals:\t{partition_values:?}", dv_info.has_vector() ); @@ -209,10 +211,11 @@ fn try_main() -> DeltaResult<()> { let scan = ScanBuilder::new(snapshot).build()?; let scan_data = scan.scan_data(&engine)?; for res in scan_data { - let (data, vector) = res?; + let (data, vector, transforms) = res?; delta_kernel::scan::state::visit_scan_files( data.as_ref(), &vector, + &transforms, (), print_scan_file, )?; diff --git a/kernel/examples/read-table-multi-threaded/src/main.rs b/kernel/examples/read-table-multi-threaded/src/main.rs index d97b6c2d3..90a7b6ba6 100644 --- a/kernel/examples/read-table-multi-threaded/src/main.rs +++ b/kernel/examples/read-table-multi-threaded/src/main.rs @@ -13,9 +13,8 @@ use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor; use delta_kernel::engine::default::DefaultEngine; use delta_kernel::engine::sync::SyncEngine; use delta_kernel::scan::state::{DvInfo, GlobalScanState, Stats}; -use delta_kernel::scan::transform_to_logical; use delta_kernel::schema::Schema; -use delta_kernel::{DeltaResult, Engine, EngineData, FileMeta, Table}; +use delta_kernel::{DeltaResult, Engine, EngineData, ExpressionRef, FileMeta, Table}; use clap::{Parser, ValueEnum}; use url::Url; @@ -81,7 +80,7 @@ fn main() -> ExitCode { struct ScanFile { path: String, size: i64, - partition_values: HashMap, + transform: Option, dv_info: DvInfo, } @@ -111,12 +110,13 @@ fn send_scan_file( size: i64, _stats: Option, dv_info: DvInfo, - partition_values: HashMap, + transform: Option, + _: HashMap, ) { let scan_file = ScanFile { path: path.to_string(), size, - partition_values, + transform, dv_info, }; scan_tx.send(scan_file).unwrap(); @@ -210,10 +210,11 @@ fn try_main() -> DeltaResult<()> { drop(record_batch_tx); for res in scan_data { - let (data, vector) = res?; + let (data, vector, transforms) = res?; scan_file_tx = delta_kernel::scan::state::visit_scan_files( data.as_ref(), &vector, + &transforms, scan_file_tx, send_scan_file, )?; @@ -256,7 +257,6 @@ fn do_work( ) { // get the type for the function calls let engine: &dyn Engine = engine.as_ref(); - let physical_schema = scan_state.physical_schema.clone(); // in a loop, try and get a ScanFile. Note that `recv` will return an `Err` when the other side // hangs up, which indicates there's no more data to process. 
while let Ok(scan_file) = scan_file_rx.recv() { @@ -287,21 +287,26 @@ fn do_work( // vector let read_results = engine .get_parquet_handler() - .read_parquet_files(&[meta], physical_schema.clone(), None) + .read_parquet_files(&[meta], scan_state.physical_schema.clone(), None) .unwrap(); for read_result in read_results { let read_result = read_result.unwrap(); let len = read_result.len(); - - // ask the kernel to transform the physical data into the correct logical form - let logical = transform_to_logical( - engine, - read_result, - &scan_state, - &scan_file.partition_values, - ) - .unwrap(); + // to transform the physical data into the correct logical form + let logical = if let Some(ref transform) = scan_file.transform { + engine + .get_expression_handler() + .get_evaluator( + scan_state.physical_schema.clone(), + transform.as_ref().clone(), // TODO: Maybe eval should take a ref + scan_state.logical_schema.clone().into(), + ) + .evaluate(read_result.as_ref()) + .unwrap() + } else { + read_result + }; let record_batch = to_arrow(logical).unwrap(); diff --git a/kernel/src/engine/arrow_expression.rs b/kernel/src/engine/arrow_expression.rs index 8ee54ebd0..a9f00dca6 100644 --- a/kernel/src/engine/arrow_expression.rs +++ b/kernel/src/engine/arrow_expression.rs @@ -21,6 +21,7 @@ use arrow_schema::{ }; use arrow_select::concat::concat; use itertools::Itertools; +use tracing::debug; use super::arrow_conversion::LIST_ARRAY_ROOT; use super::arrow_utils::make_arrow_error; @@ -537,6 +538,7 @@ pub struct DefaultExpressionEvaluator { impl ExpressionEvaluator for DefaultExpressionEvaluator { fn evaluate(&self, batch: &dyn EngineData) -> DeltaResult> { + debug!("Arrow evaluator evaluating: {:#?}", self.expression.as_ref()); let batch = batch .any_ref() .downcast_ref::() diff --git a/kernel/src/engine_data.rs b/kernel/src/engine_data.rs index e421d0ad6..701461a95 100644 --- a/kernel/src/engine_data.rs +++ b/kernel/src/engine_data.rs @@ -129,7 +129,9 @@ pub trait TypedGetData<'a, T> { fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult>; fn get(&'a self, row_index: usize, field_name: &str) -> DeltaResult { let val = self.get_opt(row_index, field_name)?; - val.ok_or_else(|| Error::MissingData(format!("Data missing for field {field_name}"))) + val.ok_or_else(|| { + Error::MissingData(format!("Data missing for field {field_name}")).with_backtrace() + }) } } diff --git a/kernel/src/scan/log_replay.rs b/kernel/src/scan/log_replay.rs index fb5c2b0fa..5d316f167 100644 --- a/kernel/src/scan/log_replay.rs +++ b/kernel/src/scan/log_replay.rs @@ -1,15 +1,16 @@ use std::clone::Clone; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::{Arc, LazyLock}; +use itertools::Itertools; use tracing::debug; use super::data_skipping::DataSkippingFilter; -use super::ScanData; +use super::{ScanData, Transform}; use crate::actions::get_log_add_schema; use crate::engine_data::{GetData, RowVisitor, TypedGetData as _}; use crate::expressions::{column_expr, column_name, ColumnName, Expression, ExpressionRef}; -use crate::scan::DeletionVectorDescriptor; +use crate::scan::{DeletionVectorDescriptor, TransformExpr}; use crate::schema::{ColumnNamesAndTypes, DataType, MapType, SchemaRef, StructField, StructType}; use crate::utils::require; use crate::{DeltaResult, Engine, EngineData, Error, ExpressionEvaluator}; @@ -44,12 +45,17 @@ struct LogReplayScanner { struct AddRemoveDedupVisitor<'seen> { seen: &'seen mut HashSet, selection_vector: Vec, + logical_schema: SchemaRef, + transform: 
Option>, + transforms: HashMap, is_log_batch: bool, } impl AddRemoveDedupVisitor<'_> { /// Checks if log replay already processed this logical file (in which case the current action /// should be ignored). If not already seen, register it so we can recognize future duplicates. + /// Returns `true` if we have seen the file and should ignore it, `false` if we have not seen it + /// and should process it. fn check_and_record_seen(&mut self, key: FileActionKey) -> bool { // Note: each (add.path + add.dv_unique_id()) pair has a // unique Add + Remove pair in the log. For example: @@ -83,11 +89,11 @@ impl AddRemoveDedupVisitor<'_> { // have a remove with a path at index 4. In either case, extract the three dv getters at // indexes that immediately follow a valid path index. let (path, dv_getters, is_add) = if let Some(path) = getters[0].get_str(i, "add.path")? { - (path, &getters[1..4], true) + (path, &getters[2..5], true) } else if !self.is_log_batch { return Ok(false); - } else if let Some(path) = getters[4].get_opt(i, "remove.path")? { - (path, &getters[5..8], false) + } else if let Some(path) = getters[5].get_opt(i, "remove.path")? { + (path, &getters[6..9], false) } else { return Ok(false); }; @@ -103,7 +109,34 @@ impl AddRemoveDedupVisitor<'_> { // Process both adds and removes, but only return not already-seen adds let file_key = FileActionKey::new(path, dv_unique_id); - Ok(!self.check_and_record_seen(file_key) && is_add) + let have_seen = self.check_and_record_seen(file_key); + if is_add && !have_seen { + // compute transform here + if let Some(ref transform) = self.transform { + let partition_values: HashMap<_, _> = getters[1].get(i, "add.partitionValues")?; + let transforms = transform + .iter() + .map(|transform_expr| match transform_expr { + TransformExpr::Partition(field_idx) => { + let field = self.logical_schema.fields.get_index(*field_idx); + let Some((_, field)) = field else { + return Err(Error::generic( + "logical schema did not contain expected field, can't transform data", + )); + }; + let name = field.physical_name(); + let value_expression = + super::parse_partition_value(partition_values.get(name), field.data_type())?; + Ok(value_expression.into()) + } + TransformExpr::Static(field_expr) => Ok(field_expr.clone()), + }) + .try_collect()?; + self.transforms + .insert(i, Arc::new(Expression::Struct(transforms))); + } + } + Ok(!have_seen && is_add) } } @@ -113,8 +146,10 @@ impl RowVisitor for AddRemoveDedupVisitor<'_> { static NAMES_AND_TYPES: LazyLock = LazyLock::new(|| { const STRING: DataType = DataType::STRING; const INTEGER: DataType = DataType::INTEGER; + let ss_map: DataType = MapType::new(STRING, STRING, true).into(); let types_and_names = vec![ (STRING, column_name!("add.path")), + (ss_map, column_name!("add.partitionValues")), (STRING, column_name!("add.deletionVector.storageType")), (STRING, column_name!("add.deletionVector.pathOrInlineDv")), (INTEGER, column_name!("add.deletionVector.offset")), @@ -132,12 +167,12 @@ impl RowVisitor for AddRemoveDedupVisitor<'_> { } else { // All checkpoint actions are already reconciled and Remove actions in checkpoint files // only serve as tombstones for vacuum jobs. So we only need to examine the adds here. 
- (&names[..4], &types[..4]) + (&names[..5], &types[..5]) } } fn visit<'a>(&mut self, row_count: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<()> { - let expected_getters = if self.is_log_batch { 8 } else { 4 }; + let expected_getters = if self.is_log_batch { 9 } else { 5 }; require!( getters.len() == expected_getters, Error::InternalError(format!( @@ -207,6 +242,8 @@ impl LogReplayScanner { &mut self, add_transform: &dyn ExpressionEvaluator, actions: &dyn EngineData, + logical_schema: SchemaRef, + transform: Option>, is_log_batch: bool, ) -> DeltaResult { // Apply data skipping to get back a selection vector for actions that passed skipping. We @@ -220,6 +257,9 @@ impl LogReplayScanner { let mut visitor = AddRemoveDedupVisitor { seen: &mut self.seen, selection_vector, + logical_schema, + transform, + transforms: HashMap::new(), is_log_batch, }; visitor.visit_rows_of(actions)?; @@ -227,7 +267,7 @@ impl LogReplayScanner { // TODO: Teach expression eval to respect the selection vector we just computed so carefully! let selection_vector = visitor.selection_vector; let result = add_transform.evaluate(actions)?; - Ok((result, selection_vector)) + Ok((result, selection_vector, visitor.transforms)) } } @@ -235,9 +275,11 @@ impl LogReplayScanner { /// `(engine_data, selection_vec)`. Each row that is selected in the returned `engine_data` _must_ /// be processed to complete the scan. Non-selected rows _must_ be ignored. The boolean flag /// indicates whether the record batch is a log or checkpoint batch. -pub fn scan_action_iter( +pub(crate) fn scan_action_iter( engine: &dyn Engine, action_iter: impl Iterator, bool)>>, + logical_schema: SchemaRef, + transform: Option>, physical_predicate: Option<(ExpressionRef, SchemaRef)>, ) -> impl Iterator> { let mut log_scanner = LogReplayScanner::new(engine, physical_predicate); @@ -249,20 +291,40 @@ pub fn scan_action_iter( action_iter .map(move |action_res| { let (batch, is_log_batch) = action_res?; - log_scanner.process_scan_batch(add_transform.as_ref(), batch.as_ref(), is_log_batch) + log_scanner.process_scan_batch( + add_transform.as_ref(), + batch.as_ref(), + logical_schema.clone(), + transform.clone(), + is_log_batch, + ) }) - .filter(|res| res.as_ref().map_or(true, |(_, sv)| sv.contains(&true))) + .filter(|res| res.as_ref().map_or(true, |(_, sv, _)| sv.contains(&true))) } #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::{collections::HashMap, sync::Arc}; + use crate::expressions::{column_name, Scalar}; use crate::scan::{ + get_state_info, state::{DvInfo, Stats}, - test_utils::{add_batch_simple, add_batch_with_remove, run_with_validate_callback}, + test_utils::{ + add_batch_simple, add_batch_with_partition_col, add_batch_with_remove, + run_with_validate_callback, + }, + Scan, + }; + use crate::Expression; + use crate::{ + engine::sync::SyncEngine, + schema::{DataType, SchemaRef, StructField, StructType}, + ExpressionRef, }; + use super::scan_action_iter; + // dv-info is more complex to validate, we validate that works in the test for visit_scan_files // in state.rs fn validate_simple( @@ -271,6 +333,7 @@ mod tests { size: i64, stats: Option, _: DvInfo, + _: Option, part_vals: HashMap, ) { assert_eq!( @@ -288,6 +351,8 @@ mod tests { fn test_scan_action_iter() { run_with_validate_callback( vec![add_batch_simple()], + None, // not testing schema + None, // not testing transform &[true, false], (), validate_simple, @@ -298,9 +363,76 @@ mod tests { fn test_scan_action_iter_with_remove() { run_with_validate_callback( 
vec![add_batch_with_remove()], + None, // not testing schema + None, // not testing transform &[false, false, true, false], (), validate_simple, ); } + + #[test] + fn test_no_transforms() { + let batch = vec![add_batch_simple()]; + let logical_schema = Arc::new(crate::schema::StructType::new(vec![])); + let iter = scan_action_iter( + &SyncEngine::new(), + batch.into_iter().map(|batch| Ok((batch as _, true))), + logical_schema, + None, + None, + ); + for res in iter { + let (_batch, _sel, transforms) = res.unwrap(); + assert!(transforms.is_empty(), "Should have no transforms"); + } + } + + #[test] + fn test_simple_transform() { + let schema: SchemaRef = Arc::new(StructType::new([ + StructField::new("value", DataType::INTEGER, true), + StructField::new("date", DataType::DATE, true), + ])); + let partition_cols = ["date".to_string()]; + let state_info = get_state_info(schema.as_ref(), &partition_cols).unwrap(); + let static_transform = Some(Arc::new(Scan::get_static_transform(&state_info.all_fields))); + let batch = vec![add_batch_with_partition_col()]; + let iter = scan_action_iter( + &SyncEngine::new(), + batch.into_iter().map(|batch| Ok((batch as _, true))), + schema, + static_transform, + None, + ); + + fn validate_transform(transform: Option<&ExpressionRef>, expected_date_offset: i32) { + assert!(transform.is_some()); + if let Expression::Struct(inner) = transform.unwrap().as_ref() { + if let Expression::Column(ref name) = inner[0] { + assert_eq!(name, &column_name!("value"), "First col should be 'value'"); + } else { + panic!("Expected first expression to be a column"); + } + if let Expression::Literal(ref scalar) = inner[1] { + assert_eq!( + scalar, + &Scalar::Date(expected_date_offset), + "Didn't get expected date offset" + ); + } else { + panic!("Expected second expression to be a literal"); + } + } else { + panic!("Transform should always be a struct expr"); + } + } + + for res in iter { + let (_batch, _sel, transforms) = res.unwrap(); + assert_eq!(transforms.len(), 2, "Should have two transforms"); + validate_transform(transforms.get(&0), 17511); + validate_transform(transforms.get(&1), 17510); + } + } } diff --git a/kernel/src/scan/mod.rs b/kernel/src/scan/mod.rs index e0d345b56..81f57e3a1 100644 --- a/kernel/src/scan/mod.rs +++ b/kernel/src/scan/mod.rs @@ -320,7 +320,24 @@ pub enum ColumnType { Partition(usize), } -pub type ScanData = (Box, Vec); +/// A transform is ultimately a `Struct` expr. This holds the set of expressions that make that struct expr up +type Transform = Vec; + +/// Transforms aren't computed all at once. So static ones can just go straight to `Expression`, but +/// things like partition columns need to filled in. This enum holds an expression that's part of a +/// `Transform`. +pub(crate) enum TransformExpr { + Static(Expression), + Partition(usize), +} + +// TODO(nick): Make this a struct in a follow-on PR +// (data, deletion_vec, transforms) +pub type ScanData = ( + Box, + Vec, + HashMap, +); /// The result of building a scan over a table. This can be used to get the actual data from /// scanning the table. @@ -359,6 +376,21 @@ impl Scan { } } + /// Convert the parts of the transform that can be computed statically into `Expression`s. For + /// parts that cannot be computed statically, include enough metadata so lower levels of + /// processing can create and fill in an expression. 
+ fn get_static_transform(all_fields: &[ColumnType]) -> Transform { + all_fields + .iter() + .map(|field| match field { + ColumnType::Selected(col_name) => { + TransformExpr::Static(ColumnName::new([col_name]).into()) + } + ColumnType::Partition(idx) => TransformExpr::Partition(*idx), + }) + .collect() + } + /// Get an iterator of [`EngineData`]s that should be included in scan for a query. This handles /// log-replay, reconciling Add and Remove actions, and applying data skipping (if /// possible). Each item in the returned iterator is a tuple of: @@ -371,11 +403,26 @@ impl Scan { /// the query. NB: If you are using the default engine and plan to call arrow's /// `filter_record_batch`, you _need_ to extend this vector to the full length of the batch or /// arrow will drop the extra rows. + /// - `HashMap`: Transformation expressions that need to be applied. For each + /// row at index `i` in the above data, if an expression exists in this map for key `i`, the + /// associated expression _must_ be applied to the data read from the file specified by the + /// row. The resultant schema for this expression is guaranteed to be `Scan.schema()`. If + /// there is no entry for a row `i` in this map, no expression need be applied and the data + /// read from disk is already in the correct logical state. pub fn scan_data( &self, engine: &dyn Engine, ) -> DeltaResult>> { - // NOTE: This is a cheap arc clone + // Compute the static part of the transformation. This is `None` if no transformation is + // needed (currently just means no partition cols, but will be extended for other transforms + // as we support them) + let static_transform = if self.have_partition_cols + || self.snapshot.column_mapping_mode != ColumnMappingMode::None + { + Some(Arc::new(Scan::get_static_transform(&self.all_fields))) + } else { + None + }; let physical_predicate = match self.physical_predicate.clone() { PhysicalPredicate::StaticSkipAll => return Ok(None.into_iter().flatten()), PhysicalPredicate::Some(predicate, schema) => Some((predicate, schema)), @@ -384,6 +431,8 @@ impl Scan { let it = scan_action_iter( engine, self.replay_for_scan_data(engine)?, + self.logical_schema.clone(), + static_transform, physical_predicate, ); Ok(Some(it).into_iter().flatten()) @@ -432,7 +481,7 @@ impl Scan { path: String, size: i64, dv_info: DvInfo, - partition_values: HashMap, + transform: Option, } fn scan_data_callback( batches: &mut Vec, @@ -440,13 +489,14 @@ impl Scan { size: i64, _: Option, dv_info: DvInfo, - partition_values: HashMap, + transform: Option, + _: HashMap, ) { batches.push(ScanFile { path: path.to_string(), size, dv_info, - partition_values, + transform, }); } @@ -458,15 +508,19 @@ impl Scan { let global_state = Arc::new(self.global_scan_state()); let table_root = self.snapshot.table_root.clone(); let physical_predicate = self.physical_predicate(); - let all_fields = self.all_fields.clone(); - let have_partition_cols = self.have_partition_cols; let scan_data = self.scan_data(engine.as_ref())?; let scan_files_iter = scan_data .map(|res| { - let (data, vec) = res?; + let (data, vec, transforms) = res?; let scan_files = vec![]; - state::visit_scan_files(data.as_ref(), &vec, scan_files, scan_data_callback) + state::visit_scan_files( + data.as_ref(), + &vec, + &transforms, + scan_files, + scan_data_callback, + ) }) // Iterator>> to Iterator> .flatten_ok(); @@ -497,18 +551,21 @@ impl Scan { // Arc clones let engine = engine.clone(); let global_state = global_state.clone(); - let all_fields = all_fields.clone(); 
Ok(read_result_iter.map(move |read_result| -> DeltaResult<_> { let read_result = read_result?; // to transform the physical data into the correct logical form - let logical = transform_to_logical_internal( - engine.as_ref(), - read_result, - &global_state, - &scan_file.partition_values, - &all_fields, - have_partition_cols, - ); + let logical = if let Some(ref transform) = scan_file.transform { + engine + .get_expression_handler() + .get_evaluator( + global_state.physical_schema.clone(), + transform.as_ref().clone(), // TODO: Maybe eval should take a ref + global_state.logical_schema.clone().into(), + ) + .evaluate(read_result.as_ref()) + } else { + Ok(read_result) + }; let len = logical.as_ref().map_or(0, |res| res.len()); // need to split the dv_mask. what's left in dv_mask covers this result, and rest // will cover the following results. we `take()` out of `selection_vector` to avoid @@ -625,73 +682,6 @@ pub fn selection_vector( Ok(deletion_treemap_to_bools(dv_treemap)) } -/// Transform the raw data read from parquet into the correct logical form, based on the provided -/// global scan state and partition values -pub fn transform_to_logical( - engine: &dyn Engine, - data: Box, - global_state: &GlobalScanState, - partition_values: &HashMap, -) -> DeltaResult> { - let state_info = get_state_info( - &global_state.logical_schema, - &global_state.partition_columns, - )?; - transform_to_logical_internal( - engine, - data, - global_state, - partition_values, - &state_info.all_fields, - state_info.have_partition_cols, - ) -} - -// We have this function because `execute` can save `all_fields` and `have_partition_cols` in the -// scan, and then reuse them for each batch transform -fn transform_to_logical_internal( - engine: &dyn Engine, - data: Box, - global_state: &GlobalScanState, - partition_values: &std::collections::HashMap, - all_fields: &[ColumnType], - have_partition_cols: bool, -) -> DeltaResult> { - let physical_schema = global_state.physical_schema.clone(); - if !have_partition_cols && global_state.column_mapping_mode == ColumnMappingMode::None { - return Ok(data); - } - // need to add back partition cols and/or fix-up mapped columns - let all_fields = all_fields - .iter() - .map(|field| match field { - ColumnType::Partition(field_idx) => { - let field = global_state.logical_schema.fields.get_index(*field_idx); - let Some((_, field)) = field else { - return Err(Error::generic( - "logical schema did not contain expected field, can't transform data", - )); - }; - let name = field.physical_name(); - let value_expression = - parse_partition_value(partition_values.get(name), field.data_type())?; - Ok(value_expression.into()) - } - ColumnType::Selected(field_name) => Ok(ColumnName::new([field_name]).into()), - }) - .try_collect()?; - let read_expression = Expression::Struct(all_fields); - let result = engine - .get_expression_handler() - .get_evaluator( - physical_schema, - read_expression, - global_state.logical_schema.clone().into(), - ) - .evaluate(data.as_ref())?; - Ok(result) -} - // some utils that are used in file_stream.rs and state.rs tests #[cfg(test)] pub(crate) mod test_utils { @@ -707,10 +697,11 @@ pub(crate) mod test_utils { sync::{json::SyncJsonHandler, SyncEngine}, }, scan::log_replay::scan_action_iter, + schema::SchemaRef, EngineData, JsonHandler, }; - use super::state::ScanCallback; + use super::{state::ScanCallback, Transform}; // TODO(nick): Merge all copies of this into one "test utils" thing fn string_array_to_engine_data(string_array: StringArray) -> Box { @@ -753,25 
+744,50 @@ pub(crate) mod test_utils { ArrowEngineData::try_from_engine_data(parsed).unwrap() } + // add batch with a `date` partition col + pub(crate) fn add_batch_with_partition_col() -> Box { + let handler = SyncJsonHandler {}; + let json_strings: StringArray = vec![ + r#"{"add":{"path":"part-00000-fae5310a-a37d-4e51-827b-c3d5516560ca-c001.snappy.parquet","partitionValues": {"date": "2017-12-11"},"size":635,"modificationTime":1677811178336,"dataChange":true,"stats":"{\"numRecords\":10,\"minValues\":{\"value\":0},\"maxValues\":{\"value\":9},\"nullCount\":{\"value\":0},\"tightBounds\":false}","tags":{"INSERTION_TIME":"1677811178336000","MIN_INSERTION_TIME":"1677811178336000","MAX_INSERTION_TIME":"1677811178336000","OPTIMIZE_TARGET_SIZE":"268435456"}}}"#, + r#"{"add":{"path":"part-00000-fae5310a-a37d-4e51-827b-c3d5516560ca-c000.snappy.parquet","partitionValues": {"date": "2017-12-10"},"size":635,"modificationTime":1677811178336,"dataChange":true,"stats":"{\"numRecords\":10,\"minValues\":{\"value\":0},\"maxValues\":{\"value\":9},\"nullCount\":{\"value\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1677811178336000","MIN_INSERTION_TIME":"1677811178336000","MAX_INSERTION_TIME":"1677811178336000","OPTIMIZE_TARGET_SIZE":"268435456"},"deletionVector":{"storageType":"u","pathOrInlineDv":"vBn[lx{q8@P<9BNH/isA","offset":1,"sizeInBytes":36,"cardinality":2}}}"#, + r#"{"metaData":{"id":"testId","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"value\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["date"],"configuration":{"delta.enableDeletionVectors":"true","delta.columnMapping.mode":"none"},"createdTime":1677811175819}}"#, + ] + .into(); + let output_schema = get_log_schema().clone(); + let parsed = handler + .parse_json(string_array_to_engine_data(json_strings), output_schema) + .unwrap(); + ArrowEngineData::try_from_engine_data(parsed).unwrap() + } + + /// Create a scan action iter and validate what's called back. 
If you pass `None` as + /// `logical_schema`, `transform` should also be `None` #[allow(clippy::vec_box)] pub(crate) fn run_with_validate_callback( batch: Vec>, + logical_schema: Option, + transform: Option>, expected_sel_vec: &[bool], context: T, validate_callback: ScanCallback, ) { + let logical_schema = + logical_schema.unwrap_or_else(|| Arc::new(crate::schema::StructType::new(vec![]))); let iter = scan_action_iter( &SyncEngine::new(), batch.into_iter().map(|batch| Ok((batch as _, true))), + logical_schema, + transform, None, ); let mut batch_count = 0; for res in iter { - let (batch, sel) = res.unwrap(); + let (batch, sel, transforms) = res.unwrap(); assert_eq!(sel, expected_sel_vec); crate::scan::state::visit_scan_files( batch.as_ref(), &sel, + &transforms, context.clone(), validate_callback, ) @@ -975,6 +991,7 @@ mod tests { _size: i64, _: Option, dv_info: DvInfo, + _transform: Option, _partition_values: HashMap, ) { paths.push(path.to_string()); @@ -982,8 +999,14 @@ mod tests { } let mut files = vec![]; for data in scan_data { - let (data, vec) = data?; - files = state::visit_scan_files(data.as_ref(), &vec, files, scan_data_callback)?; + let (data, vec, transforms) = data?; + files = state::visit_scan_files( + data.as_ref(), + &vec, + &transforms, + files, + scan_data_callback, + )?; } Ok(files) } diff --git a/kernel/src/scan/state.rs b/kernel/src/scan/state.rs index b57f0c120..d5c1cb9e8 100644 --- a/kernel/src/scan/state.rs +++ b/kernel/src/scan/state.rs @@ -5,6 +5,7 @@ use std::sync::LazyLock; use crate::actions::deletion_vector::deletion_treemap_to_bools; use crate::utils::require; +use crate::ExpressionRef; use crate::{ actions::{deletion_vector::DeletionVectorDescriptor, visitors::visit_deletion_vector_at}, engine_data::{GetData, RowVisitor, TypedGetData as _}, @@ -104,6 +105,7 @@ pub type ScanCallback = fn( size: i64, stats: Option, dv_info: DvInfo, + transform: Option, partition_values: HashMap, ); @@ -138,12 +140,14 @@ pub type ScanCallback = fn( pub fn visit_scan_files( data: &dyn EngineData, selection_vector: &[bool], + transforms: &HashMap, context: T, callback: ScanCallback, ) -> DeltaResult { let mut visitor = ScanFileVisitor { callback, selection_vector, + transforms, context, }; visitor.visit_rows_of(data)?; @@ -154,6 +158,7 @@ pub fn visit_scan_files( struct ScanFileVisitor<'a, T> { callback: ScanCallback, selection_vector: &'a [bool], + transforms: &'a HashMap, context: T, } impl RowVisitor for ScanFileVisitor<'_, T> { @@ -201,6 +206,7 @@ impl RowVisitor for ScanFileVisitor<'_, T> { size, stats, dv_info, + self.transforms.get(&row_index).cloned(), // cheap Arc clone partition_values, ) } @@ -213,7 +219,10 @@ impl RowVisitor for ScanFileVisitor<'_, T> { mod tests { use std::collections::HashMap; - use crate::scan::test_utils::{add_batch_simple, run_with_validate_callback}; + use crate::{ + scan::test_utils::{add_batch_simple, run_with_validate_callback}, + ExpressionRef, + }; use super::{DvInfo, Stats}; @@ -228,6 +237,7 @@ mod tests { size: i64, stats: Option, dv_info: DvInfo, + transform: Option, part_vals: HashMap, ) { assert_eq!( @@ -242,6 +252,7 @@ mod tests { assert!(dv_info.deletion_vector.is_some()); let dv = dv_info.deletion_vector.unwrap(); assert_eq!(dv.unique_id(), "uvBn[lx{q8@P<9BNH/isA@1"); + assert!(transform.is_none()); assert_eq!(context.id, 2); } @@ -250,6 +261,8 @@ mod tests { let context = TestContext { id: 2 }; run_with_validate_callback( vec![add_batch_simple()], + None, // not testing schema + None, // not testing transform &[true, false], 
context, validate_visit, diff --git a/kernel/tests/read.rs b/kernel/tests/read.rs index ae49b70e2..790c804df 100644 --- a/kernel/tests/read.rs +++ b/kernel/tests/read.rs @@ -10,9 +10,9 @@ use delta_kernel::actions::deletion_vector::split_vector; use delta_kernel::engine::arrow_data::ArrowEngineData; use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor; use delta_kernel::engine::default::DefaultEngine; -use delta_kernel::expressions::{column_expr, BinaryOperator, Expression}; +use delta_kernel::expressions::{column_expr, BinaryOperator, Expression, ExpressionRef}; use delta_kernel::scan::state::{visit_scan_files, DvInfo, Stats}; -use delta_kernel::scan::{transform_to_logical, Scan}; +use delta_kernel::scan::Scan; use delta_kernel::schema::{DataType, Schema}; use delta_kernel::{Engine, FileMeta, Table}; use object_store::{memory::InMemory, path::Path, ObjectStore}; @@ -339,7 +339,7 @@ struct ScanFile { path: String, size: i64, dv_info: DvInfo, - partition_values: HashMap, + transform: Option, } fn scan_data_callback( @@ -348,13 +348,14 @@ fn scan_data_callback( size: i64, _stats: Option, dv_info: DvInfo, - partition_values: HashMap, + transform: Option, + _: HashMap, ) { batches.push(ScanFile { path: path.to_string(), size, dv_info, - partition_values, + transform, }); } @@ -369,8 +370,14 @@ fn read_with_scan_data( let scan_data = scan.scan_data(engine)?; let mut scan_files = vec![]; for data in scan_data { - let (data, vec) = data?; - scan_files = visit_scan_files(data.as_ref(), &vec, scan_files, scan_data_callback)?; + let (data, vec, transforms) = data?; + scan_files = visit_scan_files( + data.as_ref(), + &vec, + &transforms, + scan_files, + scan_data_callback, + )?; } let mut batches = vec![]; @@ -397,15 +404,20 @@ fn read_with_scan_data( for read_result in read_results { let read_result = read_result.unwrap(); let len = read_result.len(); - - // ask the kernel to transform the physical data into the correct logical form - let logical = transform_to_logical( - engine, - read_result, - &global_state, - &scan_file.partition_values, - ) - .unwrap(); + // to transform the physical data into the correct logical form + let logical = if let Some(ref transform) = scan_file.transform { + engine + .get_expression_handler() + .get_evaluator( + global_state.physical_schema.clone(), + transform.as_ref().clone(), // TODO: Maybe eval should take a ref + global_state.logical_schema.clone().into(), + ) + .evaluate(read_result.as_ref()) + .unwrap() + } else { + read_result + }; let record_batch = to_arrow(logical).unwrap(); let rest = split_vector(selection_vector.as_mut(), len, Some(true));
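
For reference, the end-to-end consumption pattern these changes converge on — `scan_data()` now yields a `(data, selection_vector, transforms)` triple, and any per-file transform expression is evaluated to turn physical data into the logical shape — can be sketched in one place. This is a minimal Rust sketch, not part of the patch: `read_logical`, `collect_scan_file`, and the local `ScanFile` struct are illustrative names, deletion-vector filtering and error recovery are omitted, and the kernel calls mirror the ones used in kernel/tests/read.rs and the multi-threaded example above.

// Sketch only: how an engine drives the post-change scan API.
use std::collections::HashMap;

use delta_kernel::scan::state::{visit_scan_files, DvInfo, Stats};
use delta_kernel::scan::Scan;
use delta_kernel::{DeltaResult, Engine, EngineData, ExpressionRef, FileMeta};
use url::Url;

struct ScanFile {
    path: String,
    size: i64,
    transform: Option<ExpressionRef>,
}

// ScanCallback: collect each selected file plus its (optional) transform.
fn collect_scan_file(
    files: &mut Vec<ScanFile>,
    path: &str,
    size: i64,
    _stats: Option<Stats>,
    _dv_info: DvInfo,
    transform: Option<ExpressionRef>,
    _partition_values: HashMap<String, String>,
) {
    files.push(ScanFile { path: path.to_string(), size, transform });
}

fn read_logical(engine: &dyn Engine, scan: &Scan) -> DeltaResult<Vec<Box<dyn EngineData>>> {
    let global_state = scan.global_scan_state();
    let table_root = Url::parse(&global_state.table_root)?;

    // scan_data now yields (data, selection_vector, transforms); visit_scan_files takes
    // the transform map and hands each callback its row's transform, if any.
    let mut files = vec![];
    for res in scan.scan_data(engine)? {
        let (data, selection, transforms) = res?;
        files = visit_scan_files(data.as_ref(), &selection, &transforms, files, collect_scan_file)?;
    }

    let mut logical_batches = vec![];
    for file in files {
        let meta = FileMeta {
            location: table_root.join(&file.path)?,
            last_modified: 0,
            size: file.size as usize,
        };
        let reads = engine.get_parquet_handler().read_parquet_files(
            &[meta],
            global_state.physical_schema.clone(),
            None,
        )?;
        for batch in reads {
            let batch = batch?;
            // If the kernel attached a transform for this file, evaluating it yields the
            // logical shape (partition columns added, column mapping applied, ...);
            // otherwise the physical data is already in logical form.
            let logical = match &file.transform {
                Some(transform) => engine
                    .get_expression_handler()
                    .get_evaluator(
                        global_state.physical_schema.clone(),
                        transform.as_ref().clone(),
                        global_state.logical_schema.clone().into(),
                    )
                    .evaluate(batch.as_ref())?,
                None => batch,
            };
            logical_batches.push(logical);
        }
    }
    Ok(logical_batches)
}

The transform is computed once per selected file during log replay (in AddRemoveDedupVisitor), so an engine no longer parses partition values or builds the physical-to-logical expression itself; when a row has no entry in the transform map, the data read from the file is already in its logical form.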