From b276d479918400105017db1f7f46dcb67b52206d Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Thu, 14 Dec 2023 19:32:46 -0500 Subject: [PATCH 01/31] Add test for DataFrame::write_table (#8531) * add test for DataFrame::write_table * remove duplicate let df=... * remove println! --- .../datasource/physical_plan/parquet/mod.rs | 95 ++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 641b7bbb1596..847ea6505632 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -752,7 +752,7 @@ mod tests { use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::file_format::parquet::test_util::store_parquet; use crate::datasource::file_format::test_util::scan_format; - use crate::datasource::listing::{FileRange, PartitionedFile}; + use crate::datasource::listing::{FileRange, ListingOptions, PartitionedFile}; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use crate::physical_plan::displayable; @@ -772,8 +772,8 @@ mod tests { }; use arrow_array::Date64Array; use chrono::{TimeZone, Utc}; - use datafusion_common::ScalarValue; use datafusion_common::{assert_contains, ToDFSchema}; + use datafusion_common::{FileType, GetExt, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -1941,6 +1941,96 @@ mod tests { Ok(schema) } + #[tokio::test] + async fn write_table_results() -> Result<()> { + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + // let mut ctx = create_ctx(&tmp_dir, 4).await?; + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); + let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; + // register csv file with the execution context + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema), + ) + .await?; + + // register a local file system object store for /tmp directory + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); + + // Configure listing options + let file_format = ParquetFormat::default().with_enable_pruning(Some(true)); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::PARQUET.get_ext()); + + // execute a simple query and write the results to parquet + let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + std::fs::create_dir(&out_dir).unwrap(); + let df = ctx.sql("SELECT c1, c2 FROM test").await?; + let schema: Schema = df.schema().into(); + // Register a listing table - this will use all files in the directory as data sources + // for the query + ctx.register_listing_table( + "my_table", + &out_dir, + listing_options, + Some(Arc::new(schema)), + None, + ) + .await + .unwrap(); + df.write_table("my_table", DataFrameWriteOptions::new()) + .await?; + + // create a new context and verify that the results were saved to a partitioned parquet file + let ctx = SessionContext::new(); + + // get write_id + let mut paths = fs::read_dir(&out_dir).unwrap(); + let path = paths.next(); + let name = path + .unwrap()? + .path() + .file_name() + .expect("Should be a file name") + .to_str() + .expect("Should be a str") + .to_owned(); + let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); + let write_id = parsed_id.to_owned(); + + // register each partition as well as the top level dir + ctx.register_parquet( + "part0", + &format!("{out_dir}/{write_id}_0.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default()) + .await?; + + let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?; + let allparts = ctx + .sql("SELECT c1, c2 FROM allparts") + .await? + .collect() + .await?; + + let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum(); + + assert_eq!(part0[0].schema(), allparts[0].schema()); + + assert_eq!(allparts_count, 40); + + Ok(()) + } + #[tokio::test] async fn write_parquet_results() -> Result<()> { // create partitioned input file and context @@ -1985,7 +2075,6 @@ mod tests { .to_str() .expect("Should be a str") .to_owned(); - println!("{name}"); let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); let write_id = parsed_id.to_owned(); From 28e7f60cf7d4fb87eeaf4e4c1102eb54bfb67426 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 15 Dec 2023 15:00:10 +0300 Subject: [PATCH 02/31] Generate empty column at placeholder exec (#8553) --- datafusion/physical-plan/src/placeholder_row.rs | 7 ++++--- datafusion/sqllogictest/test_files/window.slt | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 94f32788530b..3ab3de62f37a 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -27,6 +27,7 @@ use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning use arrow::array::{ArrayRef, NullArray}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::RecordBatchOptions; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -59,9 +60,7 @@ impl PlaceholderRowExec { fn data(&self) -> Result> { Ok({ let n_field = self.schema.fields.len(); - // hack for https://github.com/apache/arrow-datafusion/pull/3242 - let n_field = if n_field == 0 { 1 } else { n_field }; - vec![RecordBatch::try_new( + vec![RecordBatch::try_new_with_options( Arc::new(Schema::new( (0..n_field) .map(|i| { @@ -75,6 +74,8 @@ impl PlaceholderRowExec { ret }) .collect(), + // Even if column number is empty we can generate single row. + &RecordBatchOptions::new().with_row_count(Some(1)), )?] }) } diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 7b628f9b6f14..6198209aaac5 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3793,3 +3793,9 @@ select a, ---- 1 1 2 1 + +query I +select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x +---- +1 +1 From f54eeea08eafc1c434d67ede4f39d5c2fb14dfdb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 15 Dec 2023 07:18:50 -0500 Subject: [PATCH 03/31] Minor: Remove now dead SUPPORTED_STRUCT_TYPES (#8480) --- datafusion/expr/src/lib.rs | 1 - datafusion/expr/src/struct_expressions.rs | 35 ----------------------- 2 files changed, 36 deletions(-) delete mode 100644 datafusion/expr/src/struct_expressions.rs diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 6172d17365ad..48532e13dcd7 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -49,7 +49,6 @@ pub mod field_util; pub mod function; pub mod interval_arithmetic; pub mod logical_plan; -pub mod struct_expressions; pub mod tree_node; pub mod type_coercion; pub mod utils; diff --git a/datafusion/expr/src/struct_expressions.rs b/datafusion/expr/src/struct_expressions.rs deleted file mode 100644 index bbfcac0e2396..000000000000 --- a/datafusion/expr/src/struct_expressions.rs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::DataType; - -/// Currently supported types by the struct function. -pub static SUPPORTED_STRUCT_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8, - DataType::LargeUtf8, -]; From 82235aeaec0eb096b762181ce323f4e39f8250a9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 15 Dec 2023 16:14:54 +0300 Subject: [PATCH 04/31] [MINOR]: Add getter methods to first and last value (#8555) --- .../physical-expr/src/aggregate/first_last.rs | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 0dc27dede8b6..5e2012bdbb67 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -61,6 +61,31 @@ impl FirstValue { ordering_req, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } } impl AggregateExpr for FirstValue { @@ -285,6 +310,31 @@ impl LastValue { ordering_req, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } } impl AggregateExpr for LastValue { From bf0073c03ace1e4212f5895c529592d9925bf28d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Fri, 15 Dec 2023 16:10:12 +0200 Subject: [PATCH 05/31] [MINOR]: Some code changes and a new empty batch guard for SHJ (#8557) * minor changes * Fix imports --------- Co-authored-by: Mehmet Ozan Kabak --- .../src/joins/stream_join_utils.rs | 83 ++++++++++++++++++- .../src/joins/symmetric_hash_join.rs | 64 +------------- 2 files changed, 83 insertions(+), 64 deletions(-) diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 5083f96b01fb..2f74bd1c4bb2 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -23,8 +23,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::handle_async_state; use crate::joins::utils::{JoinFilter, JoinHashMapType}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use crate::{handle_async_state, metrics}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; @@ -824,6 +825,10 @@ pub trait EagerJoinStream { ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } + self.set_state(EagerJoinStreamState::PullLeft); self.process_batch_from_right(batch) } @@ -849,6 +854,9 @@ pub trait EagerJoinStream { ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } self.set_state(EagerJoinStreamState::PullRight); self.process_batch_from_left(batch) } @@ -874,7 +882,12 @@ pub trait EagerJoinStream { &mut self, ) -> Result>> { match self.left_stream().next().await { - Some(Ok(batch)) => self.process_batch_after_right_end(batch), + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } + self.process_batch_after_right_end(batch) + } Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::BothExhausted { @@ -899,7 +912,12 @@ pub trait EagerJoinStream { &mut self, ) -> Result>> { match self.right_stream().next().await { - Some(Ok(batch)) => self.process_batch_after_left_end(batch), + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } + self.process_batch_after_left_end(batch) + } Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::BothExhausted { @@ -1020,6 +1038,65 @@ pub trait EagerJoinStream { fn state(&mut self) -> EagerJoinStreamState; } +#[derive(Debug)] +pub struct StreamJoinSideMetrics { + /// Number of batches consumed by this operator + pub(crate) input_batches: metrics::Count, + /// Number of rows consumed by this operator + pub(crate) input_rows: metrics::Count, +} + +/// Metrics for HashJoinExec +#[derive(Debug)] +pub struct StreamJoinMetrics { + /// Number of left batches/rows consumed by this operator + pub(crate) left: StreamJoinSideMetrics, + /// Number of right batches/rows consumed by this operator + pub(crate) right: StreamJoinSideMetrics, + /// Memory used by sides in bytes + pub(crate) stream_memory_usage: metrics::Gauge, + /// Number of batches produced by this operator + pub(crate) output_batches: metrics::Count, + /// Number of rows produced by this operator + pub(crate) output_rows: metrics::Count, +} + +impl StreamJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let left = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let right = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let stream_memory_usage = + MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + left, + right, + output_batches, + stream_memory_usage, + output_rows, + } + } +} + #[cfg(test)] pub mod tests { use std::sync::Arc; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 95f15877b960..00a7f23ebae7 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -37,7 +37,8 @@ use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, record_visited_indices, EagerJoinStream, - EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinStateResult, + EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, + StreamJoinStateResult, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, @@ -47,7 +48,7 @@ use crate::joins::utils::{ use crate::{ expressions::{Column, PhysicalSortExpr}, joins::StreamJoinPartitionMode, - metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; @@ -184,65 +185,6 @@ pub struct SymmetricHashJoinExec { mode: StreamJoinPartitionMode, } -#[derive(Debug)] -pub struct StreamJoinSideMetrics { - /// Number of batches consumed by this operator - pub(crate) input_batches: metrics::Count, - /// Number of rows consumed by this operator - pub(crate) input_rows: metrics::Count, -} - -/// Metrics for HashJoinExec -#[derive(Debug)] -pub struct StreamJoinMetrics { - /// Number of left batches/rows consumed by this operator - pub(crate) left: StreamJoinSideMetrics, - /// Number of right batches/rows consumed by this operator - pub(crate) right: StreamJoinSideMetrics, - /// Memory used by sides in bytes - pub(crate) stream_memory_usage: metrics::Gauge, - /// Number of batches produced by this operator - pub(crate) output_batches: metrics::Count, - /// Number of rows produced by this operator - pub(crate) output_rows: metrics::Count, -} - -impl StreamJoinMetrics { - pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let left = StreamJoinSideMetrics { - input_batches, - input_rows, - }; - - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let right = StreamJoinSideMetrics { - input_batches, - input_rows, - }; - - let stream_memory_usage = - MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); - - let output_batches = - MetricBuilder::new(metrics).counter("output_batches", partition); - - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - - Self { - left, - right, - output_batches, - stream_memory_usage, - output_rows, - } - } -} - impl SymmetricHashJoinExec { /// Tries to create a new [SymmetricHashJoinExec]. /// # Error From b7fde3ce7040c0569295c8b90d5d4f267296878e Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 15 Dec 2023 11:14:43 -0800 Subject: [PATCH 06/31] docs: update udf docs for udtf (#8546) * docs: update udf docs for udtf * docs: update header * style: run prettier * fix: fix stale comment * docs: expand on use cases --- datafusion-examples/examples/simple_udtf.rs | 1 + docs/source/library-user-guide/adding-udfs.md | 110 +++++++++++++++++- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index e120c5e7bf8e..f1d763ba6e41 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -125,6 +125,7 @@ impl TableProvider for LocalCsvTable { )?)) } } + struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 1e710bc321a2..11cf52eb3fcf 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -17,17 +17,18 @@ under the License. --> -# Adding User Defined Functions: Scalar/Window/Aggregate +# Adding User Defined Functions: Scalar/Window/Aggregate/Table Functions User Defined Functions (UDFs) are functions that can be used in the context of DataFusion execution. This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs) | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs) | -| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs) | +| UDF Type | Description | Example | +| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. @@ -432,3 +433,100 @@ Then, we can query like below: ```rust let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; ``` + +## Adding a User-Defined Table Function + +A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`. + +Because we're returning a `TableProvider`, in this example we'll use the `MemTable` data source to represent a table. This is a simple struct that holds a set of RecordBatches in memory and treats them as a table. In your case, this would be replaced with your own struct that implements `TableProvider`. + +While this is a simple example for illustrative purposes, UDTFs have a lot of potential use cases. And can be particularly useful for reading data from external sources and interactive analysis. For example, see the [example][4] for a working example that reads from a CSV file. As another example, you could use the built-in UDTF `parquet_metadata` in the CLI to read the metadata from a Parquet file. + +```console +❯ select filename, row_group_id, row_group_num_rows, row_group_bytes, stats_min, stats_max from parquet_metadata('./benchmarks/data/hits.parquet') where column_id = 17 limit 10; ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| filename | row_group_id | row_group_num_rows | row_group_bytes | stats_min | stats_max | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| ./benchmarks/data/hits.parquet | 0 | 450560 | 188921521 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 1 | 612174 | 210338885 | 0 | 109827 | +| ./benchmarks/data/hits.parquet | 2 | 344064 | 161242466 | 0 | 122484 | +| ./benchmarks/data/hits.parquet | 3 | 606208 | 235549898 | 0 | 121073 | +| ./benchmarks/data/hits.parquet | 4 | 335872 | 137103898 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 5 | 311296 | 145453612 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 6 | 303104 | 138833963 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 7 | 303104 | 191140113 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 8 | 573440 | 208038598 | 0 | 95823 | +| ./benchmarks/data/hits.parquet | 9 | 344064 | 147838157 | 0 | 73256 | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +``` + +### Writing the UDTF + +The simple UDTF used here takes a single `Int64` argument and returns a table with a single column with the value of the argument. To create a function in DataFusion, you need to implement the `TableFunctionImpl` trait. This trait has a single method, `call`, that takes a slice of `Expr`s and returns a `Result>`. + +In the `call` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some validation of the input `Expr`s, e.g. checking that the number of arguments is correct. + +```rust +use datafusion::common::plan_err; +use datafusion::datasource::function::TableFunctionImpl; +// Other imports here + +/// A table function that returns a table provider with the value as a single column +#[derive(Default)] +pub struct EchoFunction {} + +impl TableFunctionImpl for EchoFunction { + fn call(&self, exprs: &[Expr]) -> Result> { + let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + return plan_err!("First argument must be an integer"); + }; + + // Create the schema for the table + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Create a single RecordBatch with the value as a single column + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![*value]))], + )?; + + // Create a MemTable plan that returns the RecordBatch + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + + Ok(Arc::new(provider)) + } +} +``` + +### Registering and Using the UDTF + +With the UDTF implemented, you can register it with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udtf("echo", Arc::new(EchoFunction::default())); +``` + +And if all goes well, you can use it in your query: + +```rust +use datafusion::arrow::util::pretty; + +let df = ctx.sql("SELECT * FROM echo(1)").await?; + +let results = df.collect().await?; +pretty::print_batches(&results)?; +// +---+ +// | a | +// +---+ +// | 1 | +// +---+ +``` + +[1]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +[2]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +[3]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs +[4]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs From b71bec0fd7d17eeab5e8002842322082cd187a25 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 16 Dec 2023 03:18:08 +0800 Subject: [PATCH 07/31] feat: implement Unary Expr in substrait (#8534) Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 74 ++++----- .../substrait/src/logical_plan/producer.rs | 141 ++++++++++++------ .../tests/cases/roundtrip_logical_plan.rs | 40 +++++ 3 files changed, 169 insertions(+), 86 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f6b556fc6448..f64dc764a7ed 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1253,7 +1253,9 @@ struct BuiltinExprBuilder { impl BuiltinExprBuilder { pub fn try_from_name(name: &str) -> Option { match name { - "not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" => Some(Self { expr_name: name.to_string(), }), _ => None, @@ -1267,14 +1269,11 @@ impl BuiltinExprBuilder { extensions: &HashMap, ) -> Result> { match self.expr_name.as_str() { - "not" => Self::build_not_expr(f, input_schema, extensions).await, "like" => Self::build_like_expr(false, f, input_schema, extensions).await, "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, - "is_null" => { - Self::build_is_null_expr(false, f, input_schema, extensions).await - } - "is_not_null" => { - Self::build_is_null_expr(true, f, input_schema, extensions).await + "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" + | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { + Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -1282,22 +1281,39 @@ impl BuiltinExprBuilder { } } - async fn build_not_expr( + async fn build_unary_expr( + fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { if f.arguments.len() != 1 { - return not_impl_err!("Expect one argument for `NOT` expr"); + return substrait_err!("Expect one argument for {fn_name} expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `NOT` expr"); + return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + let arg = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); - Ok(Arc::new(Expr::Not(Box::new(expr)))) + let arg = Box::new(arg); + + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(Arc::new(expr)) } async fn build_like_expr( @@ -1308,25 +1324,25 @@ impl BuiltinExprBuilder { ) -> Result> { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 3 { - return not_impl_err!("Expect three arguments for `{fn_name}` expr"); + return substrait_err!("Expect three arguments for `{fn_name}` expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let expr = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let escape_char_expr = from_substrait_rex(escape_char_substrait, input_schema, extensions) @@ -1347,30 +1363,4 @@ impl BuiltinExprBuilder { case_insensitive, }))) } - - async fn build_is_null_expr( - is_not: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - extensions: &HashMap, - ) -> Result> { - let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" }; - let arg = f.arguments.first().ok_or_else(|| { - substrait_datafusion_err!("expect one argument for `{fn_name}` expr") - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - if is_not { - Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) - } else { - Ok(Arc::new(Expr::IsNull(Box::new(expr)))) - } - } - _ => substrait_err!("Invalid arguments for `{fn_name}` expression"), - } - } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c5f1278be6e0..81498964eb61 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1083,50 +1083,76 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), - Expr::IsNull(arg) => { - let arguments: Vec = vec![FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - arg, - schema, - col_ref_offset, - extension_info, - )?)), - }]; - - let function_name = "is_null".to_string(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } - Expr::IsNotNull(arg) => { - let arguments: Vec = vec![FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - arg, - schema, - col_ref_offset, - extension_info, - )?)), - }]; - - let function_name = "is_not_null".to_string(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } + Expr::Not(arg) => to_substrait_unary_scalar_fn( + "not", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + "is_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + "is_not_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + "is_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + "is_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + "is_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + "is_not_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + "is_not_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + "is_not_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::Negative(arg) => to_substrait_unary_scalar_fn( + "negative", + arg, + schema, + col_ref_offset, + extension_info, + ), _ => { not_impl_err!("Unsupported expression: {expr:?}") } @@ -1591,6 +1617,33 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }) } +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + let function_anchor = _register_function(fn_name.to_string(), extension_info); + let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + fn try_to_substrait_null(v: &ScalarValue) -> Result { let default_nullability = r#type::Nullability::Nullable as i32; match v { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 691fba864449..91d5a9469627 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -483,6 +483,46 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_not() -> Result<()> { + roundtrip("SELECT * FROM data WHERE NOT d").await +} + +#[tokio::test] +async fn roundtrip_negative() -> Result<()> { + roundtrip("SELECT * FROM data WHERE -a = 1").await +} + +#[tokio::test] +async fn roundtrip_is_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_not_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_not_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS UNKNOWN").await +} + +#[tokio::test] +async fn roundtrip_is_not_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT UNKNOWN").await +} + #[tokio::test] async fn roundtrip_union() -> Result<()> { roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await From 0fcd077c67b07092c94acae86ffaa97dfb54789a Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 16 Dec 2023 20:17:32 +0800 Subject: [PATCH 08/31] Fix `compute_record_batch_statistics` wrong with `projection` (#8489) * Minor: Improve the document format of JoinHashMap * fix `compute_record_batch_statistics` wrong with `projection` * fix test * fix test --- datafusion/physical-plan/src/common.rs | 38 +++++++++++------ .../sqllogictest/test_files/groupby.slt | 21 +++++----- datafusion/sqllogictest/test_files/joins.slt | 42 +++++++++---------- 3 files changed, 57 insertions(+), 44 deletions(-) diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 649f3a31aa7e..e83dc2525b9f 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -30,6 +30,7 @@ use crate::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; +use arrow_array::Array; use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; @@ -139,17 +140,22 @@ pub fn compute_record_batch_statistics( ) -> Statistics { let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); - let total_byte_size = batches - .iter() - .flatten() - .map(|b| b.get_array_memory_size()) - .sum(); - let projection = match projection { Some(p) => p, None => (0..schema.fields().len()).collect(), }; + let total_byte_size = batches + .iter() + .flatten() + .map(|b| { + projection + .iter() + .map(|index| b.column(*index).get_array_memory_size()) + .sum::() + }) + .sum(); + let mut column_statistics = vec![ColumnStatistics::new_unknown(); projection.len()]; for partition in batches.iter() { @@ -388,6 +394,7 @@ mod tests { datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; + use arrow_array::UInt64Array; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, Column}; @@ -685,20 +692,30 @@ mod tests { let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), Field::new("f64", DataType::Float64, false), + Field::new("u64", DataType::UInt64, false), ])); let batch = RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![1., 2., 3.])), Arc::new(Float64Array::from(vec![9., 8., 7.])), + Arc::new(UInt64Array::from(vec![4, 5, 6])), ], )?; + + // just select f32,f64 + let select_projection = Some(vec![0, 1]); + let byte_size = batch + .project(&select_projection.clone().unwrap()) + .unwrap() + .get_array_memory_size(); + let actual = - compute_record_batch_statistics(&[vec![batch]], &schema, Some(vec![0, 1])); + compute_record_batch_statistics(&[vec![batch]], &schema, select_projection); - let mut expected = Statistics { + let expected = Statistics { num_rows: Precision::Exact(3), - total_byte_size: Precision::Exact(464), // this might change a bit if the way we compute the size changes + total_byte_size: Precision::Exact(byte_size), column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Absent, @@ -715,9 +732,6 @@ mod tests { ], }; - // Prevent test flakiness due to undefined / changing implementation details - expected.total_byte_size = actual.total_byte_size.clone(); - assert_eq!(actual, expected); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index b915c439059b..44d30ba0b34c 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2021,14 +2021,15 @@ SortPreservingMergeExec: [col0@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallySorted([0]) --------------SortExec: expr=[col0@3 ASC NULLS LAST] -----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] ---------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 -------------------------MemoryExec: partitions=1, partition_sizes=[3] ---------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 -------------------------MemoryExec: partitions=1, partition_sizes=[3] +----------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] +------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +----------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +--------------------------MemoryExec: partitions=1, partition_sizes=[3] +----------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +--------------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2709,9 +2710,9 @@ SortExec: expr=[sn@2 ASC NULLS LAST] --ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] ----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] ------SortExec: expr=[sn@5 ASC NULLS LAST] ---------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, sn@5 as sn, amount@8 as amount] +--------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] ----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@4, currency@2)], filter=ts@0 >= ts@1 +------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------------MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 67e3750113da..1ad17fbb8c91 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1569,15 +1569,13 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] --CoalesceBatchesExec: target_batch_size=2 -----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] -------CoalescePartitionsExec ---------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] -----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] -------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------MemoryExec: partitions=1, partition_sizes=[1] +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1595,18 +1593,18 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] --CoalesceBatchesExec: target_batch_size=2 -----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] ------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 ----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] -------CoalesceBatchesExec: target_batch_size=2 ---------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] # Right side expr key inner join @@ -2821,13 +2819,13 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2862,13 +2860,13 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2924,7 +2922,7 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -2960,7 +2958,7 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] From 1f4c14c7b942de81c518b31be9a16dfb07e5237e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 17 Dec 2023 19:39:27 +0800 Subject: [PATCH 09/31] cleanup parquet flag (#8563) Signed-off-by: jayzhan211 --- datafusion/common/src/file_options/file_type.rs | 2 +- datafusion/common/src/file_options/mod.rs | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index b1d61b1a2567..97362bdad3cc 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -103,13 +103,13 @@ impl FromStr for FileType { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use crate::error::DataFusionError; use crate::file_options::FileType; use std::str::FromStr; #[test] - #[cfg(feature = "parquet")] fn from_str() { for (ext, file_type) in [ ("csv", FileType::CSV), diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index f0e49dd85597..1d661b17eb1c 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -296,10 +296,10 @@ impl Display for FileTypeWriterOptions { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use std::collections::HashMap; - #[cfg(feature = "parquet")] use parquet::{ basic::{Compression, Encoding, ZstdLevel}, file::properties::{EnabledStatistics, WriterVersion}, @@ -314,11 +314,9 @@ mod tests { use crate::Result; - #[cfg(feature = "parquet")] use super::{parquet_writer::ParquetWriterOptions, StatementOptions}; #[test] - #[cfg(feature = "parquet")] fn test_writeroptions_parquet_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("max_row_group_size".to_owned(), "123".to_owned()); @@ -389,7 +387,6 @@ mod tests { } #[test] - #[cfg(feature = "parquet")] fn test_writeroptions_parquet_column_specific() -> Result<()> { let mut option_map: HashMap = HashMap::new(); @@ -511,7 +508,6 @@ mod tests { #[test] // for StatementOptions - #[cfg(feature = "parquet")] fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("header".to_owned(), "true".to_owned()); @@ -540,7 +536,6 @@ mod tests { #[test] // for StatementOptions - #[cfg(feature = "parquet")] fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("compression".to_owned(), "gzip".to_owned()); From b59ddf64fc77bbd37aa761c856d47ebc473ea2e2 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sun, 17 Dec 2023 19:41:06 +0800 Subject: [PATCH 10/31] Minor: move some invariants out of the loop (#8564) --- datafusion/optimizer/src/push_down_filter.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c090fb849a82..4bea17500acc 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -559,6 +559,15 @@ fn push_down_join( let mut is_inner_join = false; let infer_predicates = if join.join_type == JoinType::Inner { is_inner_join = true; + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { + (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), + _ => None, + }) + .collect::>(); // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down // For inner joins, duplicate filters for joined columns so filters can be pushed down // to both sides. Take the following query as an example: @@ -583,16 +592,6 @@ fn push_down_join( Err(e) => return Some(Err(e)), }; - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { - (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), - _ => None, - }) - .collect::>(); - for col in columns.iter() { for (l, r) in join_col_keys.iter() { if col == l { From 0f83ffc448a4d7fb4297148f653e267a847d769a Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sun, 17 Dec 2023 19:55:29 +0800 Subject: [PATCH 11/31] feat: implement Repartition plan in substrait (#8526) * feat: implement Repartition plan in substrait Signed-off-by: Ruihang Xia * use substrait_err macro Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 96 ++++++++++++++----- .../substrait/src/logical_plan/producer.rs | 81 +++++++++++++++- .../tests/cases/roundtrip_logical_plan.rs | 36 ++++++- 3 files changed, 185 insertions(+), 28 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f64dc764a7ed..b7fee96bba1c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -27,8 +27,8 @@ use datafusion::logical_expr::{ BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, - WindowFrameUnits, + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + Repartition, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -38,7 +38,8 @@ use datafusion::{ prelude::{Column, SessionContext}, scalar::ScalarValue, }; -use substrait::proto::expression::{Literal, ScalarFunction}; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -550,6 +551,45 @@ pub async fn from_substrait_rel( let plan = plan.from_template(&plan.expressions(), &inputs); Ok(LogicalPlan::Extension(Extension { node: plan })) } + Some(RelType::Exchange(exchange)) => { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = + from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash( + partition_columns, + exchange.partition_count as usize, + ) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } + }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) + } _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } @@ -725,27 +765,9 @@ pub async fn from_substrait_rex( negated: false, }))) } - Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => { - let column = - input_schema.field(x.field as usize).qualified_column(); - Ok(Arc::new(Expr::Column(Column { - relation: column.relation, - name: column.name, - }))) - } - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - }, + Some(RexType::Selection(field_ref)) => Ok(Arc::new( + from_substrait_field_reference(field_ref, input_schema)?, + )), Some(RexType::IfThen(if_then)) => { // Parse `ifs` // If the first element does not have a `then` part, then we can assume it's a base expression @@ -1245,6 +1267,32 @@ fn from_substrait_null(null_type: &Type) -> Result { } } +fn from_substrait_field_reference( + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> Result { + match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!( + "Direct reference StructField with child is not supported" + ), + None => { + let column = input_schema.field(x.field as usize).qualified_column(); + Ok(Expr::Column(Column { + relation: column.relation, + name: column.name, + })) + } + }, + _ => not_impl_err!( + "Direct reference with types other than StructField is not supported" + ), + }, + _ => not_impl_err!("unsupported field ref type"), + } +} + /// Build [`Expr`] from its name and required inputs. struct BuiltinExprBuilder { expr_name: String, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 81498964eb61..50f872544298 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,9 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{CrossJoin, Distinct, Like, WindowFrameUnits}; +use datafusion::logical_expr::{ + CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -28,8 +30,8 @@ use datafusion::{ scalar::ScalarValue, }; -use datafusion::common::DFSchemaRef; use datafusion::common::{exec_err, internal_err, not_impl_err}; +use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ @@ -39,8 +41,9 @@ use datafusion::logical_expr::expr::{ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::CrossRel; +use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -410,6 +413,53 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Project(project_rel)), })) } + LogicalPlan::Repartition(repartition) => { + let input = + to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| { + try_to_substrait_field_reference( + e, + repartition.input.schema(), + ) + }) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) + } LogicalPlan::Extension(extension_plan) => { let extension_bytes = ctx .state() @@ -1804,6 +1854,31 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { } } +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. +fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: None, + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + fn substrait_sort_field( expr: &Expr, schema: &DFSchemaRef, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 91d5a9469627..47eb5a8f73f5 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -32,7 +32,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{ - Extension, LogicalPlan, UserDefinedLogicalNode, Volatility, + Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -738,6 +738,40 @@ async fn roundtrip_aggregate_udf() -> Result<()> { roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await } +#[tokio::test] +async fn roundtrip_repartition_roundrobin() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::RoundRobinBatch(8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_repartition_hash() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type { From 2e16c7519cb4a21d54975e56e2127039a3a6fd04 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 17 Dec 2023 07:12:10 -0500 Subject: [PATCH 12/31] Fix sort order aware file group parallelization (#8517) * Minor: Extract file group repartitioning and tests into `FileGroupRepartitioner` * Implement sort order aware redistribution --- datafusion/core/src/datasource/listing/mod.rs | 16 +- .../core/src/datasource/physical_plan/csv.rs | 14 +- .../datasource/physical_plan/file_groups.rs | 826 ++++++++++++++++++ .../physical_plan/file_scan_config.rs | 85 +- .../core/src/datasource/physical_plan/mod.rs | 344 +------- .../datasource/physical_plan/parquet/mod.rs | 16 +- .../enforce_distribution.rs | 61 +- .../test_files/repartition_scan.slt | 2 +- 8 files changed, 918 insertions(+), 446 deletions(-) create mode 100644 datafusion/core/src/datasource/physical_plan/file_groups.rs diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 5e5b96f6ba8c..e7583501f9d9 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -40,7 +40,7 @@ pub type PartitionedFileStream = /// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" /// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping /// sections of a Parquet file in parallel. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] pub struct FileRange { /// Range start pub start: i64, @@ -70,13 +70,12 @@ pub struct PartitionedFile { /// An optional field for user defined per object metadata pub extensions: Option>, } - impl PartitionedFile { /// Create a simple file without metadata or partition - pub fn new(path: String, size: u64) -> Self { + pub fn new(path: impl Into, size: u64) -> Self { Self { object_meta: ObjectMeta { - location: Path::from(path), + location: Path::from(path.into()), last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, @@ -99,9 +98,10 @@ impl PartitionedFile { version: None, }, partition_values: vec![], - range: Some(FileRange { start, end }), + range: None, extensions: None, } + .with_range(start, end) } /// Return a file reference from the given path @@ -114,6 +114,12 @@ impl PartitionedFile { pub fn path(&self) -> &Path { &self.object_meta.location } + + /// Update the file to only scan the specified range (in bytes) + pub fn with_range(mut self, start: i64, end: i64) -> Self { + self.range = Some(FileRange { start, end }); + self + } } impl From for PartitionedFile { diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 816a82543bab..0eca37da139d 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -23,7 +23,7 @@ use std::ops::Range; use std::sync::Arc; use std::task::Poll; -use super::FileScanConfig; +use super::{FileGroupPartitioner, FileScanConfig}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::{FileRange, ListingTableUrl}; use crate::datasource::physical_plan::file_stream::{ @@ -177,7 +177,7 @@ impl ExecutionPlan for CsvExec { } /// Redistribute files across partitions according to their size - /// See comments on `repartition_file_groups()` for more detail. + /// See comments on [`FileGroupPartitioner`] for more detail. /// /// Return `None` if can't get repartitioned(empty/compressed file). fn repartitioned( @@ -191,11 +191,11 @@ impl ExecutionPlan for CsvExec { return Ok(None); } - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&self.base_config.file_groups); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { let mut new_plan = self.clone(); diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs new file mode 100644 index 000000000000..6456bd5c7276 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -0,0 +1,826 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic for managing groups of [`PartitionedFile`]s in DataFusion + +use crate::datasource::listing::{FileRange, PartitionedFile}; +use itertools::Itertools; +use std::cmp::min; +use std::collections::BinaryHeap; +use std::iter::repeat_with; + +/// Repartition input files into `target_partitions` partitions, if total file size exceed +/// `repartition_file_min_size` +/// +/// This partitions evenly by file byte range, and does not have any knowledge +/// of how data is laid out in specific files. The specific `FileOpener` are +/// responsible for the actual partitioning on specific data source type. (e.g. +/// the `CsvOpener` will read lines overlap with byte range as well as +/// handle boundaries to ensure all lines will be read exactly once) +/// +/// # Example +/// +/// For example, if there are two files `A` and `B` that we wish to read with 4 +/// partitions (with 4 threads) they will be divided as follows: +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// │ File A (7MB) │ ────────▶ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// └─────────────────┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ File B (1MB) │ ┌─────────────────┐ +/// │ │ │ │ File A │ │ +/// └─────────────────┘ │ Range: 6-7MB │ +/// │ └─────────────────┘ │ +/// ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +/// +/// # Maintaining Order +/// +/// Within each group files are read sequentially. Thus, if the overall order of +/// tuples must be preserved, multiple files can not be mixed in the same group. +/// +/// In this case, the code will split the largest files evenly into any +/// available empty groups, but the overall distribution may not not be as even +/// as as even as if the order did not need to be preserved. +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ File A (6MB) │ ────────▶ │ │ +/// │ (ordered) │ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// └─────────────────┘ │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ File B (1MB) │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ (ordered) │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// └─────────────────┘ ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct FileGroupPartitioner { + /// how many partitions should be created + target_partitions: usize, + /// the minimum size for a file to be repartitioned. + repartition_file_min_size: usize, + /// if the order when reading the files must be preserved + preserve_order_within_groups: bool, +} + +impl Default for FileGroupPartitioner { + fn default() -> Self { + Self::new() + } +} + +impl FileGroupPartitioner { + /// Creates a new [`FileGroupPartitioner`] with default values: + /// 1. `target_partitions = 1` + /// 2. `repartition_file_min_size = 10MB` + /// 3. `preserve_order_within_groups = false` + pub fn new() -> Self { + Self { + target_partitions: 1, + repartition_file_min_size: 10 * 1024 * 1024, + preserve_order_within_groups: false, + } + } + + /// Set the target partitions + pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { + self.target_partitions = target_partitions; + self + } + + /// Set the minimum size at which to repartition a file + pub fn with_repartition_file_min_size( + mut self, + repartition_file_min_size: usize, + ) -> Self { + self.repartition_file_min_size = repartition_file_min_size; + self + } + + /// Set whether the order of tuples within a file must be preserved + pub fn with_preserve_order_within_groups( + mut self, + preserve_order_within_groups: bool, + ) -> Self { + self.preserve_order_within_groups = preserve_order_within_groups; + self + } + + /// Repartition input files according to the settings on this [`FileGroupPartitioner`]. + /// + /// If no repartitioning is needed or possible, return `None`. + pub fn repartition_file_groups( + &self, + file_groups: &[Vec], + ) -> Option>> { + if file_groups.is_empty() { + return None; + } + + // Perform redistribution only in case all files should be read from beginning to end + let has_ranges = file_groups.iter().flatten().any(|f| f.range.is_some()); + if has_ranges { + return None; + } + + // special case when order must be preserved + if self.preserve_order_within_groups { + self.repartition_preserving_order(file_groups) + } else { + self.repartition_evenly_by_size(file_groups) + } + } + + /// Evenly repartition files across partitions by size, ignoring any + /// existing grouping / ordering + fn repartition_evenly_by_size( + &self, + file_groups: &[Vec], + ) -> Option>> { + let target_partitions = self.target_partitions; + let repartition_file_min_size = self.repartition_file_min_size; + let flattened_files = file_groups.iter().flatten().collect::>(); + + let total_size = flattened_files + .iter() + .map(|f| f.object_meta.size as i64) + .sum::(); + if total_size < (repartition_file_min_size as i64) || total_size == 0 { + return None; + } + + let target_partition_size = + (total_size as usize + (target_partitions) - 1) / (target_partitions); + + let current_partition_index: usize = 0; + let current_partition_size: usize = 0; + + // Partition byte range evenly for all `PartitionedFile`s + let repartitioned_files = flattened_files + .into_iter() + .scan( + (current_partition_index, current_partition_size), + |state, source_file| { + let mut produced_files = vec![]; + let mut range_start = 0; + while range_start < source_file.object_meta.size { + let range_end = min( + range_start + (target_partition_size - state.1), + source_file.object_meta.size, + ); + + let mut produced_file = source_file.clone(); + produced_file.range = Some(FileRange { + start: range_start as i64, + end: range_end as i64, + }); + produced_files.push((state.0, produced_file)); + + if state.1 + (range_end - range_start) >= target_partition_size { + state.0 += 1; + state.1 = 0; + } else { + state.1 += range_end - range_start; + } + range_start = range_end; + } + Some(produced_files) + }, + ) + .flatten() + .group_by(|(partition_idx, _)| *partition_idx) + .into_iter() + .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) + .collect_vec(); + + Some(repartitioned_files) + } + + /// Redistribute file groups across size preserving order + fn repartition_preserving_order( + &self, + file_groups: &[Vec], + ) -> Option>> { + // Can't repartition and preserve order if there are more groups + // than partitions + if file_groups.len() >= self.target_partitions { + return None; + } + let num_new_groups = self.target_partitions - file_groups.len(); + + // If there is only a single file + if file_groups.len() == 1 && file_groups[0].len() == 1 { + return self.repartition_evenly_by_size(file_groups); + } + + // Find which files could be split (single file groups) + let mut heap: BinaryHeap<_> = file_groups + .iter() + .enumerate() + .filter_map(|(group_index, group)| { + // ignore groups that do not have exactly 1 file + if group.len() == 1 { + Some(ToRepartition { + source_index: group_index, + file_size: group[0].object_meta.size, + new_groups: vec![group_index], + }) + } else { + None + } + }) + .collect(); + + // No files can be redistributed + if heap.is_empty() { + return None; + } + + // Add new empty groups to which we will redistribute ranges of existing files + let mut file_groups: Vec<_> = file_groups + .iter() + .cloned() + .chain(repeat_with(Vec::new).take(num_new_groups)) + .collect(); + + // Divide up empty groups + for (group_index, group) in file_groups.iter().enumerate() { + if !group.is_empty() { + continue; + } + // Pick the file that has the largest ranges to read so far + let mut largest_group = heap.pop().unwrap(); + largest_group.new_groups.push(group_index); + heap.push(largest_group); + } + + // Distribute files to their newly assigned groups + while let Some(to_repartition) = heap.pop() { + let range_size = to_repartition.range_size() as i64; + let ToRepartition { + source_index, + file_size, + new_groups, + } = to_repartition; + assert_eq!(file_groups[source_index].len(), 1); + let original_file = file_groups[source_index].pop().unwrap(); + + let last_group = new_groups.len() - 1; + let mut range_start: i64 = 0; + let mut range_end: i64 = range_size; + for (i, group_index) in new_groups.into_iter().enumerate() { + let target_group = &mut file_groups[group_index]; + assert!(target_group.is_empty()); + + // adjust last range to include the entire file + if i == last_group { + range_end = file_size as i64; + } + target_group + .push(original_file.clone().with_range(range_start, range_end)); + range_start = range_end; + range_end += range_size; + } + } + + Some(file_groups) + } +} + +/// Tracks how a individual file will be repartitioned +#[derive(Debug, Clone, PartialEq, Eq)] +struct ToRepartition { + /// the index from which the original file will be taken + source_index: usize, + /// the size of the original file + file_size: usize, + /// indexes of which group(s) will this be distributed to (including `source_index`) + new_groups: Vec, +} + +impl ToRepartition { + // how big will each file range be when this file is read in its new groups? + fn range_size(&self) -> usize { + self.file_size / self.new_groups.len() + } +} + +impl PartialOrd for ToRepartition { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Order based on individual range +impl Ord for ToRepartition { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.range_size().cmp(&other.range_size()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// Empty file won't get partitioned + #[test] + fn repartition_empty_file_only() { + let partitioned_file_empty = pfile("empty", 0); + let file_group = vec![vec![partitioned_file_empty.clone()]]; + + let partitioned_files = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(0) + .repartition_file_groups(&file_group); + + assert_partitioned_files(None, partitioned_files); + } + + /// Repartition when there is a empty file in file groups + #[test] + fn repartition_empty_files() { + let pfile_a = pfile("a", 10); + let pfile_b = pfile("b", 10); + let pfile_empty = pfile("empty", 0); + + let empty_first = vec![ + vec![pfile_empty.clone()], + vec![pfile_a.clone()], + vec![pfile_b.clone()], + ]; + let empty_middle = vec![ + vec![pfile_a.clone()], + vec![pfile_empty.clone()], + vec![pfile_b.clone()], + ]; + let empty_last = vec![vec![pfile_a], vec![pfile_b], vec![pfile_empty]]; + + // Repartition file groups into x partitions + let expected_2 = vec![ + vec![pfile("a", 10).with_range(0, 10)], + vec![pfile("b", 10).with_range(0, 10)], + ]; + let expected_3 = vec![ + vec![pfile("a", 10).with_range(0, 7)], + vec![ + pfile("a", 10).with_range(7, 10), + pfile("b", 10).with_range(0, 4), + ], + vec![pfile("b", 10).with_range(4, 10)], + ]; + + let file_groups_tests = [empty_first, empty_middle, empty_last]; + + for fg in file_groups_tests { + let all_expected = [(2, expected_2.clone()), (3, expected_3.clone())]; + for (n_partition, expected) in all_expected { + let actual = FileGroupPartitioner::new() + .with_target_partitions(n_partition) + .with_repartition_file_min_size(10) + .repartition_file_groups(&fg); + + assert_partitioned_files(Some(expected), actual); + } + } + } + + #[test] + fn repartition_single_file() { + // Single file, single partition into multiple partitions + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 123).with_range(0, 31)], + vec![pfile("a", 123).with_range(31, 62)], + vec![pfile("a", 123).with_range(62, 93)], + vec![pfile("a", 123).with_range(93, 123)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_too_much_partitions() { + // Single file, single partition into 96 partitions + let partitioned_file = pfile("a", 8); + let single_partition = vec![vec![partitioned_file]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(96) + .with_repartition_file_min_size(5) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 8).with_range(0, 1)], + vec![pfile("a", 8).with_range(1, 2)], + vec![pfile("a", 8).with_range(2, 3)], + vec![pfile("a", 8).with_range(3, 4)], + vec![pfile("a", 8).with_range(4, 5)], + vec![pfile("a", 8).with_range(5, 6)], + vec![pfile("a", 8).with_range(6, 7)], + vec![pfile("a", 8).with_range(7, 8)], + ]); + + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_multiple_partitions() { + // Multiple files in single partition after redistribution + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 40).with_range(0, 34)], + vec![ + pfile("a", 40).with_range(34, 40), + pfile("b", 60).with_range(0, 28), + ], + vec![pfile("b", 60).with_range(28, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_same_num_partitions() { + // "Rebalance" files across partitions + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![ + pfile("a", 40).with_range(0, 40), + pfile("b", 60).with_range(0, 10), + ], + vec![pfile("b", 60).with_range(10, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_no_action_ranges() { + // No action due to Some(range) in second file + let source_partitions = vec![ + vec![pfile("a", 123)], + vec![pfile("b", 144).with_range(1, 50)], + ]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_min_size() { + // No action due to target_partition_size + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_zero_files() { + // No action due to no files + let empty_partition = vec![]; + + let partitioner = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500); + + assert_partitioned_files(None, repartition_test(partitioner, empty_partition)) + } + + #[test] + fn repartition_ordered_no_action_too_few_partitions() { + // No action as there are no new groups to redistribute to + let input_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 200)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&input_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_no_action_file_too_small() { + // No action as there are no new groups to redistribute to + let single_partition = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + // file is too small to repartition + .with_repartition_file_min_size(1000) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_one_large_file() { + // "Rebalance" the single large file across partitions + let source_partitions = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 100).with_range(0, 34)], + vec![pfile("a", 100).with_range(34, 68)], + vec![pfile("a", 100).with_range(68, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_file() { + // "Rebalance" the single large file across empty partitions, but can't split + // small file + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 30)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first third of "a" + vec![pfile("a", 100).with_range(0, 33)], + // only b in this group (can't do this) + vec![pfile("b", 30).with_range(0, 30)], + // second third of "a" + vec![pfile("a", 100).with_range(33, 66)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_files() { + // "Rebalance" two large files across empty partitions, but can't mix them + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_one_small_files() { + // "Rebalance" two large files and one small file across empty partitions + let source_partitions = vec![ + vec![pfile("a", 100)], + vec![pfile("b", 100)], + vec![pfile("c", 30)], + ]; + + let partitioner = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_repartition_file_min_size(10); + + // with 4 partitions, can only split the first large file "a" + let actual = partitioner + .with_target_partitions(4) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // All of "b" + vec![pfile("b", 100).with_range(0, 100)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + + // With 5 partitions, we can split both "a" and "b", but they can't be intermixed + let actual = partitioner + .with_target_partitions(5) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_existing_empty() { + // "Rebalance" files using existing empty partition + let source_partitions = + vec![vec![pfile("a", 100)], vec![], vec![pfile("b", 40)], vec![]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(5) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // Scan of "a" across three groups + vec![pfile("a", 100).with_range(0, 33)], + vec![pfile("a", 100).with_range(33, 66)], + // scan first half of "b" + vec![pfile("b", 40).with_range(0, 20)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + // second half of "b" + vec![pfile("b", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + #[test] + fn repartition_ordered_existing_group_multiple_files() { + // groups with multiple files in a group can not be changed, but can divide others + let source_partitions = vec![ + // two files in an existing partition + vec![pfile("a", 100), pfile("b", 100)], + vec![pfile("c", 40)], + ]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // don't try and rearrange files in the existing partition + // assuming that the caller had a good reason to put them that way. + // (it is technically possible to split off ranges from the files if desired) + vec![pfile("a", 100), pfile("b", 100)], + // first half of "c" + vec![pfile("c", 40).with_range(0, 20)], + // second half of "c" + vec![pfile("c", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + + /// Asserts that the two groups of `ParititonedFile` are the same + /// (PartitionedFile doesn't implement PartialEq) + fn assert_partitioned_files( + expected: Option>>, + actual: Option>>, + ) { + match (expected, actual) { + (None, None) => {} + (Some(_), None) => panic!("Expected Some, got None"), + (None, Some(_)) => panic!("Expected None, got Some"), + (Some(expected), Some(actual)) => { + let expected_string = format!("{:#?}", expected); + let actual_string = format!("{:#?}", actual); + assert_eq!(expected_string, actual_string); + } + } + } + + /// returns a partitioned file with the specified path and size + fn pfile(path: impl Into, file_size: u64) -> PartitionedFile { + PartitionedFile::new(path, file_size) + } + + /// repartition the file groups both with and without preserving order + /// asserting they return the same value and returns that value + fn repartition_test( + partitioner: FileGroupPartitioner, + file_groups: Vec>, + ) -> Option>> { + let repartitioned = partitioner.repartition_file_groups(&file_groups); + + let repartitioned_preserving_sort = partitioner + .with_preserve_order_within_groups(true) + .repartition_file_groups(&file_groups); + + assert_partitioned_files( + repartitioned.clone(), + repartitioned_preserving_sort.clone(), + ); + repartitioned + } +} diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index d308397ab6e2..89694ff28500 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -19,15 +19,11 @@ //! file sources. use std::{ - borrow::Cow, cmp::min, collections::HashMap, fmt::Debug, marker::PhantomData, - sync::Arc, vec, + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, }; -use super::get_projected_output_ordering; -use crate::datasource::{ - listing::{FileRange, PartitionedFile}, - object_store::ObjectStoreUrl, -}; +use super::{get_projected_output_ordering, FileGroupPartitioner}; +use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, @@ -42,7 +38,6 @@ use datafusion_common::stats::Precision; use datafusion_common::{exec_err, ColumnStatistics, Statistics}; use datafusion_physical_expr::LexOrdering; -use itertools::Itertools; use log::warn; /// Convert type to a type suitable for use as a [`ListingTable`] @@ -176,79 +171,17 @@ impl FileScanConfig { }) } - /// Repartition all input files into `target_partitions` partitions, if total file size exceed - /// `repartition_file_min_size` - /// `target_partitions` and `repartition_file_min_size` directly come from configuration. - /// - /// This function only try to partition file byte range evenly, and let specific `FileOpener` to - /// do actual partition on specific data source type. (e.g. `CsvOpener` will only read lines - /// overlap with byte range but also handle boundaries to ensure all lines will be read exactly once) + #[allow(missing_docs)] + #[deprecated(since = "33.0.0", note = "Use SessionContext::new_with_config")] pub fn repartition_file_groups( file_groups: Vec>, target_partitions: usize, repartition_file_min_size: usize, ) -> Option>> { - let flattened_files = file_groups.iter().flatten().collect::>(); - - // Perform redistribution only in case all files should be read from beginning to end - let has_ranges = flattened_files.iter().any(|f| f.range.is_some()); - if has_ranges { - return None; - } - - let total_size = flattened_files - .iter() - .map(|f| f.object_meta.size as i64) - .sum::(); - if total_size < (repartition_file_min_size as i64) || total_size == 0 { - return None; - } - - let target_partition_size = - (total_size as usize + (target_partitions) - 1) / (target_partitions); - - let current_partition_index: usize = 0; - let current_partition_size: usize = 0; - - // Partition byte range evenly for all `PartitionedFile`s - let repartitioned_files = flattened_files - .into_iter() - .scan( - (current_partition_index, current_partition_size), - |state, source_file| { - let mut produced_files = vec![]; - let mut range_start = 0; - while range_start < source_file.object_meta.size { - let range_end = min( - range_start + (target_partition_size - state.1), - source_file.object_meta.size, - ); - - let mut produced_file = source_file.clone(); - produced_file.range = Some(FileRange { - start: range_start as i64, - end: range_end as i64, - }); - produced_files.push((state.0, produced_file)); - - if state.1 + (range_end - range_start) >= target_partition_size { - state.0 += 1; - state.1 = 0; - } else { - state.1 += range_end - range_start; - } - range_start = range_end; - } - Some(produced_files) - }, - ) - .flatten() - .group_by(|(partition_idx, _)| *partition_idx) - .into_iter() - .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) - .collect_vec(); - - Some(repartitioned_files) + FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&file_groups) } } diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 14e550eab1d5..8e4dd5400b20 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -20,11 +20,13 @@ mod arrow_file; mod avro; mod csv; +mod file_groups; mod file_scan_config; mod file_stream; mod json; #[cfg(feature = "parquet")] pub mod parquet; +pub use file_groups::FileGroupPartitioner; pub(crate) use self::csv::plan_to_csv; pub use self::csv::{CsvConfig, CsvExec, CsvOpener}; @@ -537,7 +539,6 @@ mod tests { }; use arrow_schema::Field; use chrono::Utc; - use datafusion_common::config::ConfigOptions; use crate::physical_plan::{DefaultDisplay, VerboseDisplay}; @@ -809,345 +810,4 @@ mod tests { extensions: None, } } - - /// Unit tests for `repartition_file_groups()` - #[cfg(feature = "parquet")] - mod repartition_file_groups_test { - use datafusion_common::Statistics; - use itertools::Itertools; - - use super::*; - - /// Empty file won't get partitioned - #[tokio::test] - async fn repartition_empty_file_only() { - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - let file_group = vec![vec![partitioned_file_empty]]; - - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: file_group, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let partitioned_file = repartition_with_size(&parquet_exec, 4, 0); - - assert!(partitioned_file[0][0].range.is_none()); - } - - // Repartition when there is a empty file in file groups - #[tokio::test] - async fn repartition_empty_files() { - let partitioned_file_a = PartitionedFile::new("a".to_string(), 10); - let partitioned_file_b = PartitionedFile::new("b".to_string(), 10); - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - - let empty_first = vec![ - vec![partitioned_file_empty.clone()], - vec![partitioned_file_a.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_middle = vec![ - vec![partitioned_file_a.clone()], - vec![partitioned_file_empty.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_last = vec![ - vec![partitioned_file_a], - vec![partitioned_file_b], - vec![partitioned_file_empty], - ]; - - // Repartition file groups into x partitions - let expected_2 = - vec![(0, "a".to_string(), 0, 10), (1, "b".to_string(), 0, 10)]; - let expected_3 = vec![ - (0, "a".to_string(), 0, 7), - (1, "a".to_string(), 7, 10), - (1, "b".to_string(), 0, 4), - (2, "b".to_string(), 4, 10), - ]; - - //let file_groups_testset = [empty_first, empty_middle, empty_last]; - let file_groups_testset = [empty_first, empty_middle, empty_last]; - - for fg in file_groups_testset { - for (n_partition, expected) in [(2, &expected_2), (3, &expected_3)] { - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: fg.clone(), - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Arc::new( - Schema::empty(), - )), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = - repartition_with_size_to_vec(&parquet_exec, n_partition, 10); - - assert_eq!(expected, &actual); - } - } - } - - #[tokio::test] - async fn repartition_single_file() { - // Single file, single partition into multiple partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 4, 10); - let expected = vec![ - (0, "a".to_string(), 0, 31), - (1, "a".to_string(), 31, 62), - (2, "a".to_string(), 62, 93), - (3, "a".to_string(), 93, 123), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_too_much_partitions() { - // Single file, single parittion into 96 partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 8); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 96, 5); - let expected = vec![ - (0, "a".to_string(), 0, 1), - (1, "a".to_string(), 1, 2), - (2, "a".to_string(), 2, 3), - (3, "a".to_string(), 3, 4), - (4, "a".to_string(), 4, 5), - (5, "a".to_string(), 5, 6), - (6, "a".to_string(), 6, 7), - (7, "a".to_string(), 7, 8), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_multiple_partitions() { - // Multiple files in single partition after redistribution - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 3, 10); - let expected = vec![ - (0, "a".to_string(), 0, 34), - (1, "a".to_string(), 34, 40), - (1, "b".to_string(), 0, 28), - (2, "b".to_string(), 28, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_same_num_partitions() { - // "Rebalance" files across partitions - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 2, 10); - let expected = vec![ - (0, "a".to_string(), 0, 40), - (0, "b".to_string(), 0, 10), - (1, "b".to_string(), 10, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_no_action_ranges() { - // No action due to Some(range) in second file - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 123); - let mut partitioned_file_2 = PartitionedFile::new("b".to_string(), 144); - partitioned_file_2.range = Some(FileRange { start: 1, end: 50 }); - - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size(&parquet_exec, 65, 10); - assert_eq!(2, actual.len()); - } - - #[tokio::test] - async fn repartition_no_action_min_size() { - // No action due to target_partition_size - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size(&parquet_exec, 65, 500); - assert_eq!(1, actual.len()); - } - - /// Calls `ParquetExec.repartitioned` with the specified - /// `target_partitions` and `repartition_file_min_size`, returning the - /// resulting `PartitionedFile`s - fn repartition_with_size( - parquet_exec: &ParquetExec, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Vec> { - let mut config = ConfigOptions::new(); - config.optimizer.repartition_file_min_size = repartition_file_min_size; - - parquet_exec - .repartitioned(target_partitions, &config) - .unwrap() // unwrap Result - .unwrap() // unwrap Option - .as_any() - .downcast_ref::() - .unwrap() - .base_config() - .file_groups - .clone() - } - - /// Calls `repartition_with_size` and returns a tuple for each output `PartitionedFile`: - /// - /// `(partition index, file path, start, end)` - fn repartition_with_size_to_vec( - parquet_exec: &ParquetExec, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Vec<(usize, String, i64, i64)> { - let file_groups = repartition_with_size( - parquet_exec, - target_partitions, - repartition_file_min_size, - ); - - file_groups - .iter() - .enumerate() - .flat_map(|(part_idx, files)| { - files - .iter() - .map(|f| { - ( - part_idx, - f.object_meta.location.to_string(), - f.range.as_ref().unwrap().start, - f.range.as_ref().unwrap().end, - ) - }) - .collect_vec() - }) - .collect_vec() - } - } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 847ea6505632..2b10b05a273a 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -26,8 +26,8 @@ use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, }; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, DisplayAs, FileMeta, FileScanConfig, - SchemaAdapter, + parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, + FileMeta, FileScanConfig, SchemaAdapter, }; use crate::{ config::ConfigOptions, @@ -330,18 +330,18 @@ impl ExecutionPlan for ParquetExec { } /// Redistribute files across partitions according to their size - /// See comments on `get_file_groups_repartitioned()` for more detail. + /// See comments on [`FileGroupPartitioner`] for more detail. fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, ) -> Result>> { let repartition_file_min_size = config.optimizer.repartition_file_min_size; - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .repartition_file_groups(&self.base_config.file_groups); let mut new_plan = self.clone(); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index f2e04989ef66..099759741a10 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1761,6 +1761,7 @@ pub(crate) mod tests { parquet_exec_with_sort(vec![]) } + /// create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( output_ordering: Vec>, ) -> Arc { @@ -1785,7 +1786,7 @@ pub(crate) mod tests { parquet_exec_multiple_sorted(vec![]) } - // Created a sorted parquet exec with multiple files + /// Created a sorted parquet exec with multiple files fn parquet_exec_multiple_sorted( output_ordering: Vec>, ) -> Arc { @@ -3858,6 +3859,56 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn parallelization_multiple_files() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + + let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key])); + let plan = sort_required_exec(plan); + + // The groups must have only contiguous ranges of rows from the same file + // if any group has rows from multiple files, the data is no longer sorted destroyed + // https://github.com/apache/arrow-datafusion/issues/8451 + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; + let target_partitions = 3; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let target_partitions = 8; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + Ok(()) + } + #[test] /// CsvExec on compressed csv file will not be partitioned /// (Not able to decompress chunked csv file) @@ -4529,15 +4580,11 @@ pub(crate) mod tests { assert_plan_txt!(expected, physical_plan); let expected = &[ - "SortRequiredExec: [a@0 ASC]", // Since at the start of the rule ordering requirement is satisfied // EnforceDistribution rule satisfy this requirement also. - // ordering is re-satisfied by introduction of SortExec. - "SortExec: expr=[a@0 ASC]", + "SortRequiredExec: [a@0 ASC]", "FilterExec: c@2 = 0", - // ordering is lost here - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + "ParquetExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; let mut config = ConfigOptions::new(); diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 551d6d9ed48a..5dcdbb504e76 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -118,7 +118,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..197], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..201], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:201..403], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:197..394]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 # Cleanup statement ok From fc6cc48e372b0c945aa78d78207441bca2bd11bf Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sun, 17 Dec 2023 13:15:05 +0100 Subject: [PATCH 13/31] feat: support largelist in array_slice (#8561) * support largelist in array_slice * remove T trait * fix clippy --- .../physical-expr/src/array_expressions.rs | 110 ++++++++++----- datafusion/sqllogictest/test_files/array.slt | 129 ++++++++++++++++++ 2 files changed, 208 insertions(+), 31 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7fa97dad7aa6..7ccf58af832d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -524,11 +524,33 @@ pub fn array_except(args: &[ArrayRef]) -> Result { /// /// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let from_array = as_int64_array(&args[1])?; - let to_array = as_int64_array(&args[2])?; + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + _ => not_impl_err!("array_slice does not support type: {:?}", array_data_type), + } +} - let values = list_array.values(); +fn general_array_slice( + array: &GenericListArray, + from_array: &Int64Array, + to_array: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -539,72 +561,98 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { // We have the slice syntax compatible with DuckDB v0.8.1. // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. - fn adjusted_from_index(index: i64, len: usize) -> Option { + fn adjusted_from_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { // 0 ~ len - 1 let adjusted_zero_index = if index < 0 { - index + len as i64 + if let Ok(index) = index.try_into() { + index + len + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) - std::cmp::max(index - 1, 0) + if let Ok(index) = index.try_into() { + std::cmp::max(index - O::usize_as(1), O::usize_as(0)) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - fn adjusted_to_index(index: i64, len: usize) -> Option { + fn adjusted_to_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { // 0 ~ len - 1 let adjusted_zero_index = if index < 0 { // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive - index + len as i64 - 1 + if let Ok(index) = index.try_into() { + index + len - O::usize_as(1) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) - std::cmp::min(index - 1, len as i64 - 1) + if let Ok(index) = index.try_into() { + std::cmp::min(index - O::usize_as(1), len - O::usize_as(1)) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - let mut offsets = vec![0]; + let mut offsets = vec![O::usize_as(0)]; - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; let len = end - start; // len 0 indicate array is null, return empty array in this row. - if len == 0 { + if len == O::usize_as(0) { offsets.push(offsets[row_index]); continue; } // If index is null, we consider it as the minimum / maximum index of the array. let from_index = if from_array.is_null(row_index) { - Some(0) + Some(O::usize_as(0)) } else { - adjusted_from_index(from_array.value(row_index), len) + adjusted_from_index::(from_array.value(row_index), len)? }; let to_index = if to_array.is_null(row_index) { - Some(len as i64 - 1) + Some(len - O::usize_as(1)) } else { - adjusted_to_index(to_array.value(row_index), len) + adjusted_to_index::(to_array.value(row_index), len)? }; if let (Some(from), Some(to)) = (from_index, to_index) { if from <= to { - assert!(start + to as usize <= end); - mutable.extend(0, start + from as usize, start + to as usize + 1); - offsets.push(offsets[row_index] + (to - from + 1) as i32); + assert!(start + to <= end); + mutable.extend( + 0, + (start + from).to_usize().unwrap(), + (start + to + O::usize_as(1)).to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (to - from + O::usize_as(1))); } else { // invalid range, return empty array offsets.push(offsets[row_index]); @@ -617,9 +665,9 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::new(offsets.into()), + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), None, )?)) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1202a2b1e99d..210739aa51da 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -912,128 +912,235 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice scalar function #2 (with positive indexes; full array) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); ---- [1, 2, 3, 4, 5] [h, e, l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + # array_slice scalar function #3 (with positive indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); ---- [4] [l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 4, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 3); +---- +[4] [l] + # array_slice scalar function #4 (with positive indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 4, 1); +---- +[] [] + # array_slice scalar function #5 (with positive indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); ---- [2, 3, 4, 5] [l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + # array_slice scalar function #6 (with positive indexes; nested array) query ? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); ---- [[1, 2, 3, 4, 5]] +query ? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1, 1); +---- +[[1, 2, 3, 4, 5]] + # array_slice scalar function #7 (with zero and positive number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); ---- [1, 2, 3, 4] [h, e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + # array_slice scalar function #8 (with NULL and positive number) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, 3); + # array_slice scalar function #9 (with positive number and NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, NULL); + # array_slice scalar function #10 (with zero-zero) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 0), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 0); +---- +[] [] + # array_slice scalar function #11 (with NULL-NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + + # array_slice scalar function #12 (with zero and negative number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); ---- [1] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3); +---- +[1] [h, e] + # array_slice scalar function #13 (with negative number and NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, NULL); + # array_slice scalar function #14 (with NULL and negative number) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, -3); + # array_slice scalar function #15 (with negative indexes) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); ---- [2, 3, 4] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1); +---- +[2, 3, 4] [l, l] + # array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_slice scalar function #17 (with negative indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3); +---- +[] [] + # array_slice scalar function #18 (with negative indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -6); +---- +[] [] + # array_slice scalar function #19 (with negative indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7, -2), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7, -3); +---- +[] [] + # array_slice scalar function #20 (with negative indexes; nested array) query ?? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); ---- [[1, 2, 3, 4, 5]] [] +query ?? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1); +---- +[[1, 2, 3, 4, 5]] [] + + # array_slice scalar function #21 (with first positive index and last negative index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2); ---- [2] [e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2); +---- +[2] [e, l] + # array_slice scalar function #22 (with first negative index and last positive index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); ---- [4, 5] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, 5), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, 4); +---- +[4, 5] [l, l] + # list_slice scalar function #23 (function alias `array_slice`) query ?? select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice with columns query ? select array_slice(column1, column2, column3) from slices; @@ -1046,6 +1153,17 @@ select array_slice(column1, column2, column3) from slices; [41, 42, 43, 44, 45, 46] [55, 56, 57, 58, 59, 60] +query ? +select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + # TODO: support NULLS in output instead of `[]` # array_slice with columns and scalars query ??? @@ -1059,6 +1177,17 @@ select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(col [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] +query ??? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + # make_array with nulls query ??????? select make_array(make_array('a','b'), null), From b287cda40fa906dbdf035fa6a4dabe485927f42d Mon Sep 17 00:00:00 2001 From: comphead Date: Sun, 17 Dec 2023 22:55:31 -0800 Subject: [PATCH 14/31] minor: fix to support scalars (#8559) * minor: fix to support scalars * Update datafusion/sql/src/expr/function.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Co-authored-by: Andrew Lamb --- datafusion/sql/src/expr/function.rs | 3 ++ datafusion/sqllogictest/test_files/window.slt | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 73de4fa43907..3934d6701c63 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -90,6 +90,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let partition_by = window .partition_by .into_iter() + // ignore window spec PARTITION BY for scalar values + // as they do not change and thus do not generate new partitions + .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. },)) .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; let mut order_by = self.order_by_to_sort_expr( diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 6198209aaac5..864f7dc0a47d 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3794,8 +3794,36 @@ select a, 1 1 2 1 +# support scalar value in ORDER BY query I select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x ---- 1 1 + +# support scalar value in both ORDER BY and PARTITION BY, RANK function +# TODO: fix the test, some issue in RANK +#query IIIIII +#select rank() over (partition by 1 order by 1) rnk, +# rank() over (partition by a, 1 order by 1) rnk1, +# rank() over (partition by a, 1 order by a, 1) rnk2, +# rank() over (partition by 1) rnk3, +# rank() over (partition by null) rnk4, +# rank() over (partition by 1, null, a) rnk5 +#from (select 1 a union all select 2 a) x +#---- +#1 1 1 1 1 1 +#1 1 1 1 1 1 + +# support scalar value in both ORDER BY and PARTITION BY, ROW_NUMBER function +query IIIIII +select row_number() over (partition by 1 order by 1) rn, + row_number() over (partition by a, 1 order by 1) rn1, + row_number() over (partition by a, 1 order by a, 1) rn2, + row_number() over (partition by 1) rn3, + row_number() over (partition by null) rn4, + row_number() over (partition by 1, null, a) rn5 +from (select 1 a union all select 2 a) x; +---- +1 1 1 1 1 1 +2 1 1 2 2 1 \ No newline at end of file From a71a76a996a32a0f068370940ebe475ec237b4ff Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Mon, 18 Dec 2023 11:53:26 +0200 Subject: [PATCH 15/31] refactor: `HashJoinStream` state machine (#8538) * hash join state machine * StreamJoinStateResult to StatefulStreamResult * doc comments & naming & fmt * suggestions from code review Co-authored-by: Andrew Lamb * more review comments addressed * post-merge fixes --------- Co-authored-by: Andrew Lamb --- .../physical-plan/src/joins/hash_join.rs | 431 ++++++++++++------ .../src/joins/stream_join_utils.rs | 127 ++---- .../src/joins/symmetric_hash_join.rs | 25 +- datafusion/physical-plan/src/joins/utils.rs | 83 ++++ 4 files changed, 420 insertions(+), 246 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 4846d0a5e046..13ac06ee301c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -28,7 +28,6 @@ use crate::joins::utils::{ calculate_join_output_ordering, get_final_indices_from_bit_map, need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; -use crate::DisplayAs; use crate::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, @@ -38,12 +37,13 @@ use crate::{ joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, estimate_join_statistics, partitioned_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::{handle_state, DisplayAs}; use super::{ utils::{OnceAsync, OnceFut}, @@ -618,15 +618,14 @@ impl ExecutionPlan for HashJoinExec { on_right, filter: self.filter.clone(), join_type: self.join_type, - left_fut, - visited_left_side: None, right: right_stream, column_indices: self.column_indices.clone(), random_state: self.random_state.clone(), join_metrics, null_equals_null: self.null_equals_null, - is_exhausted: false, reservation, + state: HashJoinStreamState::WaitBuildSide, + build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), })) } @@ -789,6 +788,104 @@ where Ok(()) } +/// Represents build-side of hash join. +enum BuildSide { + /// Indicates that build-side not collected yet + Initial(BuildSideInitialState), + /// Indicates that build-side data has been collected + Ready(BuildSideReadyState), +} + +/// Container for BuildSide::Initial related data +struct BuildSideInitialState { + /// Future for building hash table from build-side input + left_fut: OnceFut, +} + +/// Container for BuildSide::Ready related data +struct BuildSideReadyState { + /// Collected build-side data + left_data: Arc, + /// Which build-side rows have been matched while creating output. + /// For some OUTER joins, we need to know which rows have not been matched + /// to produce the correct output. + visited_left_side: BooleanBufferBuilder, +} + +impl BuildSide { + /// Tries to extract BuildSideInitialState from BuildSide enum. + /// Returns an error if state is not Initial. + fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { + match self { + BuildSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready(&self) -> Result<&BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +/// Represents state of HashJoinStream +/// +/// Expected state transitions performed by HashJoinStream are: +/// +/// ```text +/// +/// WaitBuildSide +/// │ +/// ▼ +/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed +/// │ │ +/// │ ▼ +/// └─ ProcessProbeBatch +/// +/// ``` +enum HashJoinStreamState { + /// Initial state for HashJoinStream indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for fetching probe-side + FetchProbeBatch, + /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed + ProcessProbeBatch(ProcessProbeBatchState), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that HashJoinStream execution is completed + Completed, +} + +/// Container for HashJoinStreamState::ProcessProbeBatch related data +struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, +} + +impl HashJoinStreamState { + /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. + /// Returns an error if state is not ProcessProbeBatchState. + fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + match self { + HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), + } + } +} + /// [`Stream`] for [`HashJoinExec`] that does the actual join. /// /// This stream: @@ -808,20 +905,10 @@ struct HashJoinStream { filter: Option, /// type of the join (left, right, semi, etc) join_type: JoinType, - /// future which builds hash table from left side - left_fut: OnceFut, - /// Which left (probe) side rows have been matches while creating output. - /// For some OUTER joins, we need to know which rows have not been matched - /// to produce the correct output. - visited_left_side: Option, /// right (probe) input right: SendableRecordBatchStream, /// Random state used for hashing initialization random_state: RandomState, - /// The join output is complete. For outer joins, this is used to - /// distinguish when the input stream is exhausted and when any unmatched - /// rows are output. - is_exhausted: bool, /// Metrics join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns @@ -830,6 +917,10 @@ struct HashJoinStream { null_equals_null: bool, /// Memory reservation reservation: MemoryReservation, + /// State of the stream + state: HashJoinStreamState, + /// Build side + build_side: BuildSide, } impl RecordBatchStream for HashJoinStream { @@ -1069,19 +1160,44 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { + loop { + return match self.state { + HashJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + HashJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + HashJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + HashJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + HashJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + /// Collects build-side data by polling `OnceFut` future from initialized build-side + /// + /// Updates build-side to `Ready`, and state to `FetchProbeSide` + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); // build hash table from left (build) side, if not yet done - let left_data = match ready!(self.left_fut.get(cx)) { - Ok(left_data) => left_data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + let left_data = ready!(self + .build_side + .try_as_initial_mut()? + .left_fut + .get_shared(cx))?; build_timer.done(); // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet // and join_type requires to store it - if self.visited_left_side.is_none() - && need_produce_result_in_final(self.join_type) - { + if need_produce_result_in_final(self.join_type) { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); @@ -1089,124 +1205,167 @@ impl HashJoinStream { self.join_metrics.build_mem_used.add(visited_bitmap_size); } - let visited_left_side = self.visited_left_side.get_or_insert_with(|| { + let visited_left_side = if need_produce_result_in_final(self.join_type) { let num_rows = left_data.num_rows(); - if need_produce_result_in_final(self.join_type) { - // Some join types need to track which row has be matched or unmatched: - // `left semi` join: need to use the bitmap to produce the matched row in the left side - // `left` join: need to use the bitmap to produce the unmatched row in the left side with null - // `left anti` join: need to use the bitmap to produce the unmatched row in the left side - // `full` join: need to use the bitmap to produce the unmatched row in the left side with null - let mut buffer = BooleanBufferBuilder::new(num_rows); - buffer.append_n(num_rows, false); - buffer - } else { - BooleanBufferBuilder::new(0) - } + // Some join types need to track which row has be matched or unmatched: + // `left semi` join: need to use the bitmap to produce the matched row in the left side + // `left` join: need to use the bitmap to produce the unmatched row in the left side with null + // `left anti` join: need to use the bitmap to produce the unmatched row in the left side + // `full` join: need to use the bitmap to produce the unmatched row in the left side with null + let mut buffer = BooleanBufferBuilder::new(num_rows); + buffer.append_n(num_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) + }; + + self.state = HashJoinStreamState::FetchProbeBatch; + self.build_side = BuildSide::Ready(BuildSideReadyState { + left_data, + visited_left_side, }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, + /// otherwise updates state to `ExhaustedProbeSide` + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.right.poll_next_unpin(cx)) { + None => { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(batch)) => { + self.state = + HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { + batch, + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with matched output + /// + /// Updates state to `FetchProbeBatch` + fn process_probe_batch( + &mut self, + ) -> Result>> { + let state = self.state.try_as_process_probe_batch()?; + let build_side = self.build_side.try_as_ready_mut()?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(state.batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + let mut hashes_buffer = vec![]; - // get next right (probe) input batch - self.right - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - // one right batch in the join loop - Some(Ok(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( - left_data.hash_map(), - left_data.batch(), - &batch, - &self.on_left, - &self.on_right, - &self.random_state, - self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - ); - - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - visited_left_side.set_bit(x as usize, true); - }); - } - - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - batch.num_rows(), - self.join_type, - ); - - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - Some(result) - } - Err(err) => Some(exec_err!( - "Fail to build join indices in HashJoinExec, error:{err}" - )), - }; - timer.done(); - result - } - None => { - let timer = self.join_metrics.join_time.timer(); - if need_produce_result_in_final(self.join_type) && !self.is_exhausted - { - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_bit_map( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - self.is_exhausted = true; - Some(result) - } else { - // end of the join loop - None - } + // get the matched two indices for the on condition + let left_right_indices = build_equal_condition_join_indices( + build_side.left_data.hash_map(), + build_side.left_data.batch(), + &state.batch, + &self.on_left, + &self.on_right, + &self.random_state, + self.null_equals_null, + &mut hashes_buffer, + self.filter.as_ref(), + JoinSide::Left, + None, + ); + + let result = match left_right_indices { + Ok((left_side, right_side)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { + left_side.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); } - Some(err) => Some(err), - }) + + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + state.batch.num_rows(), + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(state.batch.num_rows()); + result + } + Err(err) => { + exec_err!("Fail to build join indices in HashJoinExec, error:{err}") + } + }; + timer.done(); + + self.state = HashJoinStreamState::FetchProbeBatch; + + Ok(StatefulStreamResult::Ready(Some(result?))) + } + + /// Processes unmatched build-side rows for certain join types and produces output batch + /// + /// Updates state to `Completed` + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let timer = self.join_metrics.join_time.timer(); + + if !need_produce_result_in_final(self.join_type) { + self.state = HashJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Continue); + } + + let build_side = self.build_side.try_as_ready()?; + + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + timer.done(); + + self.state = HashJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(result?))) } } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 2f74bd1c4bb2..64a976a1e39f 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -23,9 +23,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::joins::utils::{JoinFilter, JoinHashMapType}; +use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{handle_async_state, metrics}; +use crate::{handle_async_state, handle_state, metrics}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; @@ -624,73 +624,6 @@ pub fn record_visited_indices( } } -/// The `handle_state` macro is designed to process the result of a state-changing -/// operation, typically encountered in implementations of `EagerJoinStream`. It -/// operates on a `StreamJoinStateResult` by matching its variants and executing -/// corresponding actions. This macro is used to streamline code that deals with -/// state transitions, reducing boilerplate and improving readability. -/// -/// # Cases -/// -/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the -/// stream join operation should proceed to the next step. -/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with the -/// result, either yielding a value or indicating the stream is awaiting more -/// data. -/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue -/// during the stream join operation. -/// -/// # Arguments -/// -/// * `$match_case`: An expression that evaluates to a `Result>`. -#[macro_export] -macro_rules! handle_state { - ($match_case:expr) => { - match $match_case { - Ok(StreamJoinStateResult::Continue) => continue, - Ok(StreamJoinStateResult::Ready(result)) => { - Poll::Ready(Ok(result).transpose()) - } - Err(e) => Poll::Ready(Some(Err(e))), - } - }; -} - -/// The `handle_async_state` macro adapts the `handle_state` macro for use in -/// asynchronous operations, particularly when dealing with `Poll` results within -/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing -/// function using `poll_unpin` and then passes the result to `handle_state` for -/// further processing. -/// -/// # Arguments -/// -/// * `$state_func`: An async function or future that returns a -/// `Result>`. -/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. -/// -#[macro_export] -macro_rules! handle_async_state { - ($state_func:expr, $cx:expr) => { - $crate::handle_state!(ready!($state_func.poll_unpin($cx))) - }; -} - -/// Represents the result of a stateful operation on `EagerJoinStream`. -/// -/// This enumueration indicates whether the state produced a result that is -/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). -/// -/// Variants: -/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. -/// - `Continue`: Indicates that the operation is not yet complete and requires further -/// processing or more data. When this variant is returned, it typically means that the -/// current invocation of the state did not produce a final result, and the operation -/// should be invoked again later with more data and possibly with a different state. -pub enum StreamJoinStateResult { - Ready(T), - Continue, -} - /// Represents the various states of an eager join stream operation. /// /// This enum is used to track the current state of streaming during a join @@ -819,14 +752,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after pulling the batch. + /// * `Result>>` - The state result after pulling the batch. async fn fetch_next_from_right_stream( &mut self, - ) -> Result>> { + ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.set_state(EagerJoinStreamState::PullLeft); @@ -835,7 +768,7 @@ pub trait EagerJoinStream { Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::RightExhausted); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -848,14 +781,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after pulling the batch. + /// * `Result>>` - The state result after pulling the batch. async fn fetch_next_from_left_stream( &mut self, - ) -> Result>> { + ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.set_state(EagerJoinStreamState::PullRight); self.process_batch_from_left(batch) @@ -863,7 +796,7 @@ pub trait EagerJoinStream { Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::LeftExhausted); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -877,14 +810,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after checking the exhaustion state. + /// * `Result>>` - The state result after checking the exhaustion state. async fn handle_right_stream_end( &mut self, - ) -> Result>> { + ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.process_batch_after_right_end(batch) } @@ -893,7 +826,7 @@ pub trait EagerJoinStream { self.set_state(EagerJoinStreamState::BothExhausted { final_result: false, }); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -907,14 +840,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after checking the exhaustion state. + /// * `Result>>` - The state result after checking the exhaustion state. async fn handle_left_stream_end( &mut self, - ) -> Result>> { + ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.process_batch_after_left_end(batch) } @@ -923,7 +856,7 @@ pub trait EagerJoinStream { self.set_state(EagerJoinStreamState::BothExhausted { final_result: false, }); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -936,10 +869,10 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after both streams are exhausted. + /// * `Result>>` - The state result after both streams are exhausted. fn prepare_for_final_results_after_exhaustion( &mut self, - ) -> Result>> { + ) -> Result>> { self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); self.process_batches_before_finalization() } @@ -952,11 +885,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after processing the batch. + /// * `Result>>` - The state result after processing the batch. fn process_batch_from_right( &mut self, batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles a pulled batch from the left stream. /// @@ -966,11 +899,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after processing the batch. + /// * `Result>>` - The state result after processing the batch. fn process_batch_from_left( &mut self, batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the situation when only the left stream is exhausted. /// @@ -980,11 +913,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after the left stream is exhausted. + /// * `Result>>` - The state result after the left stream is exhausted. fn process_batch_after_left_end( &mut self, right_batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the situation when only the right stream is exhausted. /// @@ -994,20 +927,20 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after the right stream is exhausted. + /// * `Result>>` - The state result after the right stream is exhausted. fn process_batch_after_right_end( &mut self, left_batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the final state after both streams are exhausted. /// /// # Returns /// - /// * `Result>>` - The final state result after processing. + /// * `Result>>` - The final state result after processing. fn process_batches_before_finalization( &mut self, - ) -> Result>>; + ) -> Result>>; /// Provides mutable access to the right stream. /// diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 00a7f23ebae7..b9101b57c3e5 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -38,12 +38,11 @@ use crate::joins::stream_join_utils::{ convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, record_visited_indices, EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, - StreamJoinStateResult, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter, - JoinOn, + JoinOn, StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, @@ -956,13 +955,13 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_from_right( &mut self, batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.perform_join_for_given_side(batch, JoinSide::Right) .map(|maybe_batch| { if maybe_batch.is_some() { - StreamJoinStateResult::Ready(maybe_batch) + StatefulStreamResult::Ready(maybe_batch) } else { - StreamJoinStateResult::Continue + StatefulStreamResult::Continue } }) } @@ -970,13 +969,13 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_from_left( &mut self, batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.perform_join_for_given_side(batch, JoinSide::Left) .map(|maybe_batch| { if maybe_batch.is_some() { - StreamJoinStateResult::Ready(maybe_batch) + StatefulStreamResult::Ready(maybe_batch) } else { - StreamJoinStateResult::Continue + StatefulStreamResult::Continue } }) } @@ -984,20 +983,20 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_after_left_end( &mut self, right_batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.process_batch_from_right(right_batch) } fn process_batch_after_right_end( &mut self, left_batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.process_batch_from_left(left_batch) } fn process_batches_before_finalization( &mut self, - ) -> Result>> { + ) -> Result>> { // Get the left side results: let left_result = build_side_determined_results( &self.left, @@ -1025,9 +1024,9 @@ impl EagerJoinStream for SymmetricHashJoinStream { // Update the metrics: self.metrics.output_batches.add(1); self.metrics.output_rows.add(batch.num_rows()); - return Ok(StreamJoinStateResult::Ready(result)); + return Ok(StatefulStreamResult::Ready(result)); } - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } fn right_stream(&mut self) -> &mut SendableRecordBatchStream { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5e01ca227cf5..eae65ce9c26b 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -849,6 +849,22 @@ impl OnceFut { ), } } + + /// Get shared reference to the result of the computation if it is ready, without consuming it + pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll>> { + if let OnceFutState::Pending(fut) = &mut self.state { + let r = ready!(fut.poll_unpin(cx)); + self.state = OnceFutState::Ready(r); + } + + match &self.state { + OnceFutState::Pending(_) => unreachable!(), + OnceFutState::Ready(r) => Poll::Ready( + r.clone() + .map_err(|e| DataFusionError::External(Box::new(e))), + ), + } + } } /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and @@ -1277,6 +1293,73 @@ pub fn prepare_sorted_exprs( Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) } +/// The `handle_state` macro is designed to process the result of a state-changing +/// operation, encountered e.g. in implementations of `EagerJoinStream`. It +/// operates on a `StatefulStreamResult` by matching its variants and executing +/// corresponding actions. This macro is used to streamline code that deals with +/// state transitions, reducing boilerplate and improving readability. +/// +/// # Cases +/// +/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the +/// stream join operation should proceed to the next step. +/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the +/// result, either yielding a value or indicating the stream is awaiting more +/// data. +/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue +/// during the stream join operation. +/// +/// # Arguments +/// +/// * `$match_case`: An expression that evaluates to a `Result>`. +#[macro_export] +macro_rules! handle_state { + ($match_case:expr) => { + match $match_case { + Ok(StatefulStreamResult::Continue) => continue, + Ok(StatefulStreamResult::Ready(result)) => { + Poll::Ready(Ok(result).transpose()) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + }; +} + +/// The `handle_async_state` macro adapts the `handle_state` macro for use in +/// asynchronous operations, particularly when dealing with `Poll` results within +/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing +/// function using `poll_unpin` and then passes the result to `handle_state` for +/// further processing. +/// +/// # Arguments +/// +/// * `$state_func`: An async function or future that returns a +/// `Result>`. +/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. +/// +#[macro_export] +macro_rules! handle_async_state { + ($state_func:expr, $cx:expr) => { + $crate::handle_state!(ready!($state_func.poll_unpin($cx))) + }; +} + +/// Represents the result of an operation on stateful join stream. +/// +/// This enumueration indicates whether the state produced a result that is +/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). +/// +/// Variants: +/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. +/// - `Continue`: Indicates that the operation is not yet complete and requires further +/// processing or more data. When this variant is returned, it typically means that the +/// current invocation of the state did not produce a final result, and the operation +/// should be invoked again later with more data and possibly with a different state. +pub enum StatefulStreamResult { + Ready(T), + Continue, +} + #[cfg(test)] mod tests { use std::pin::Pin; From a1e959d87a66da7060bd005b1993b824c0683a63 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 18 Dec 2023 10:55:49 +0000 Subject: [PATCH 16/31] Remove ListingTable and FileScanConfig Unbounded (#8540) (#8573) * Remove ListingTable and FileScanConfig Unbounded (#8540) * Fix substrait * Fix logical conflicts * Add deleted tests as ignored --------- Co-authored-by: Mustafa Akur --- datafusion-examples/examples/csv_opener.rs | 1 - datafusion-examples/examples/json_opener.rs | 1 - .../core/src/datasource/file_format/mod.rs | 1 - .../src/datasource/file_format/options.rs | 48 ++---- .../core/src/datasource/listing/table.rs | 152 ------------------ .../src/datasource/listing_table_factory.rs | 16 +- .../datasource/physical_plan/arrow_file.rs | 4 - .../core/src/datasource/physical_plan/avro.rs | 7 - .../core/src/datasource/physical_plan/csv.rs | 4 - .../physical_plan/file_scan_config.rs | 3 - .../datasource/physical_plan/file_stream.rs | 1 - .../core/src/datasource/physical_plan/json.rs | 8 - .../core/src/datasource/physical_plan/mod.rs | 4 - .../datasource/physical_plan/parquet/mod.rs | 4 - datafusion/core/src/execution/context/mod.rs | 11 +- .../combine_partial_final_agg.rs | 1 - .../enforce_distribution.rs | 5 - .../src/physical_optimizer/enforce_sorting.rs | 15 +- .../physical_optimizer/projection_pushdown.rs | 2 - .../replace_with_order_preserving_variants.rs | 92 ++++++----- .../core/src/physical_optimizer/test_utils.rs | 24 +-- datafusion/core/src/test/mod.rs | 3 - datafusion/core/src/test_util/mod.rs | 25 +-- datafusion/core/src/test_util/parquet.rs | 1 - .../core/tests/parquet/custom_reader.rs | 1 - datafusion/core/tests/parquet/page_pruning.rs | 1 - .../core/tests/parquet/schema_coercion.rs | 2 - datafusion/core/tests/sql/joins.rs | 42 ++--- .../proto/src/physical_plan/from_proto.rs | 1 - .../tests/cases/roundtrip_physical_plan.rs | 1 - .../substrait/src/physical_plan/consumer.rs | 1 - .../tests/cases/roundtrip_physical_plan.rs | 1 - 32 files changed, 102 insertions(+), 381 deletions(-) diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_opener.rs index 15fb07ded481..96753c8c5260 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_opener.rs @@ -67,7 +67,6 @@ async fn main() -> Result<()> { limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion-examples/examples/json_opener.rs b/datafusion-examples/examples/json_opener.rs index 1a3dbe57be75..ee33f969caa9 100644 --- a/datafusion-examples/examples/json_opener.rs +++ b/datafusion-examples/examples/json_opener.rs @@ -70,7 +70,6 @@ async fn main() -> Result<()> { limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 7c2331548e5e..12c9fb91adb1 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -165,7 +165,6 @@ pub(crate) mod test_util { limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, ) diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 4c7557a4a9c0..d389137785ff 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -21,7 +21,6 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::{plan_err, DataFusionError}; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -72,8 +71,6 @@ pub struct CsvReadOptions<'a> { pub table_partition_cols: Vec<(String, DataType)>, /// File compression type pub file_compression_type: FileCompressionType, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, } @@ -97,7 +94,6 @@ impl<'a> CsvReadOptions<'a> { file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, - infinite: false, file_sort_order: vec![], } } @@ -108,12 +104,6 @@ impl<'a> CsvReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify delimiter to use for CSV read pub fn delimiter(mut self, delimiter: u8) -> Self { self.delimiter = delimiter; @@ -324,8 +314,6 @@ pub struct AvroReadOptions<'a> { pub file_extension: &'a str, /// Partition Columns pub table_partition_cols: Vec<(String, DataType)>, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, } impl<'a> Default for AvroReadOptions<'a> { @@ -334,7 +322,6 @@ impl<'a> Default for AvroReadOptions<'a> { schema: None, file_extension: DEFAULT_AVRO_EXTENSION, table_partition_cols: vec![], - infinite: false, } } } @@ -349,12 +336,6 @@ impl<'a> AvroReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify schema to use for AVRO read pub fn schema(mut self, schema: &'a Schema) -> Self { self.schema = Some(schema); @@ -466,21 +447,17 @@ pub trait ReadOptions<'a> { state: SessionState, table_path: ListingTableUrl, schema: Option<&'a Schema>, - infinite: bool, ) -> Result where 'a: 'async_trait, { - match (schema, infinite) { - (Some(s), _) => Ok(Arc::new(s.to_owned())), - (None, false) => Ok(self - .to_listing_options(config) - .infer_schema(&state, &table_path) - .await?), - (None, true) => { - plan_err!("Schema inference for infinite data sources is not supported.") - } + if let Some(s) = schema { + return Ok(Arc::new(s.to_owned())); } + + self.to_listing_options(config) + .infer_schema(&state, &table_path) + .await } } @@ -500,7 +477,6 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) - .with_infinite_source(self.infinite) } async fn get_resolved_schema( @@ -509,7 +485,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -535,7 +511,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -551,7 +527,6 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) .with_file_sort_order(self.file_sort_order.clone()) } @@ -561,7 +536,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -575,7 +550,6 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) } async fn get_resolved_schema( @@ -584,7 +558,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -606,7 +580,7 @@ impl ReadOptions<'_> for ArrowReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 0ce1b43fe456..4c13d9d443ca 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -246,11 +246,6 @@ pub struct ListingOptions { /// multiple equivalent orderings, the outer `Vec` will have a /// single element. pub file_sort_order: Vec>, - /// Infinite source means that the input is not guaranteed to end. - /// Currently, CSV, JSON, and AVRO formats are supported. - /// In order to support infinite inputs, DataFusion may adjust query - /// plans (e.g. joins) to run the given query in full pipelining mode. - pub infinite_source: bool, /// This setting when true indicates that the table is backed by a single file. /// Any inserts to the table may only append to this existing file. pub single_file: bool, @@ -274,30 +269,11 @@ impl ListingOptions { collect_stat: true, target_partitions: 1, file_sort_order: vec![], - infinite_source: false, single_file: false, file_type_write_options: None, } } - /// Set unbounded assumption on [`ListingOptions`] and returns self. - /// - /// ``` - /// use std::sync::Arc; - /// use datafusion::datasource::{listing::ListingOptions, file_format::csv::CsvFormat}; - /// use datafusion::prelude::SessionContext; - /// let ctx = SessionContext::new(); - /// let listing_options = ListingOptions::new(Arc::new( - /// CsvFormat::default() - /// )).with_infinite_source(true); - /// - /// assert_eq!(listing_options.infinite_source, true); - /// ``` - pub fn with_infinite_source(mut self, infinite_source: bool) -> Self { - self.infinite_source = infinite_source; - self - } - /// Set file extension on [`ListingOptions`] and returns self. /// /// ``` @@ -557,7 +533,6 @@ pub struct ListingTable { options: ListingOptions, definition: Option, collected_statistics: FileStatisticsCache, - infinite_source: bool, constraints: Constraints, column_defaults: HashMap, } @@ -587,7 +562,6 @@ impl ListingTable { for (part_col_name, part_col_type) in &options.table_partition_cols { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } - let infinite_source = options.infinite_source; let table = Self { table_paths: config.table_paths, @@ -596,7 +570,6 @@ impl ListingTable { options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - infinite_source, constraints: Constraints::empty(), column_defaults: HashMap::new(), }; @@ -729,7 +702,6 @@ impl TableProvider for ListingTable { limit, output_ordering: self.try_create_output_ordering()?, table_partition_cols, - infinite_source: self.infinite_source, }, filters.as_ref(), ) @@ -943,7 +915,6 @@ impl ListingTable { #[cfg(test)] mod tests { use std::collections::HashMap; - use std::fs::File; use super::*; #[cfg(feature = "parquet")] @@ -955,7 +926,6 @@ mod tests { use crate::{ assert_batches_eq, datasource::file_format::avro::AvroFormat, - execution::options::ReadOptions, logical_expr::{col, lit}, test::{columns, object_store::register_test_store}, }; @@ -967,37 +937,8 @@ mod tests { use datafusion_common::{assert_contains, GetExt, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; use datafusion_physical_expr::PhysicalSortExpr; - use rstest::*; use tempfile::TempDir; - /// It creates dummy file and checks if it can create unbounded input executors. - async fn unbounded_table_helper( - file_type: FileType, - listing_option: ListingOptions, - infinite_data: bool, - ) -> Result<()> { - let ctx = SessionContext::new(); - register_test_store( - &ctx, - &[(&format!("table/file{}", file_type.get_ext()), 100)], - ); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_option) - .with_schema(Arc::new(schema)); - // Create a table - let table = ListingTable::try_new(config)?; - // Create executor from table - let source_exec = table.scan(&ctx.state(), None, &[], None).await?; - - assert_eq!(source_exec.unbounded_output(&[])?, infinite_data); - - Ok(()) - } - #[tokio::test] async fn read_single_file() -> Result<()> { let ctx = SessionContext::new(); @@ -1205,99 +1146,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn unbounded_csv_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.csv"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_json_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.json"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_json( - "test", - tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_avro_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.avro"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_avro( - "test", - tmp_dir.path().to_str().unwrap(), - AvroReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[rstest] - #[tokio::test] - async fn unbounded_csv_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = CsvReadOptions::new().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::CSV, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_json_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = NdJsonReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::JSON, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_avro_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = AvroReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::AVRO, listing_options, infinite_data).await - } - #[tokio::test] async fn test_assert_list_files_for_scan_grouping() -> Result<()> { // more expected partitions than files diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index a9d0c3a0099e..7c859ee988d5 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -133,21 +133,9 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; - // look for 'infinite' as an option - let infinite_source = cmd.unbounded; - let mut statement_options = StatementOptions::from(&cmd.options); // Extract ListingTable specific options if present or set default - let unbounded = if infinite_source { - statement_options.take_str_option("unbounded"); - infinite_source - } else { - statement_options - .take_bool_option("unbounded")? - .unwrap_or(false) - }; - let single_file = statement_options .take_bool_option("single_file")? .unwrap_or(false); @@ -159,6 +147,7 @@ impl TableProviderFactory for ListingTableFactory { } } statement_options.take_bool_option("create_local_path")?; + statement_options.take_str_option("unbounded"); let file_type = file_format.file_type(); @@ -207,8 +196,7 @@ impl TableProviderFactory for ListingTableFactory { .with_table_partition_cols(table_partition_cols) .with_file_sort_order(cmd.order_exprs.clone()) .with_single_file(single_file) - .with_write_options(file_type_writer_options) - .with_infinite_source(unbounded); + .with_write_options(file_type_writer_options); let resolved_schema = match provided_schema { None => options.infer_schema(state, &table_path).await?, diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 30b55db28491..ae1e879d0da1 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -93,10 +93,6 @@ impl ExecutionPlan for ArrowExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 885b4c5d3911..e448bf39f427 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -89,10 +89,6 @@ impl ExecutionPlan for AvroExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() @@ -276,7 +272,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); let mut results = avro_exec @@ -348,7 +343,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); @@ -419,7 +413,6 @@ mod tests { limit: None, table_partition_cols: vec![Field::new("date", DataType::Utf8, false)], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 0eca37da139d..0c34d22e9fa9 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -146,10 +146,6 @@ impl ExecutionPlan for CsvExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - /// See comments on `impl ExecutionPlan for ParquetExec`: output order can't be fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 89694ff28500..516755e4d293 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -99,8 +99,6 @@ pub struct FileScanConfig { pub table_partition_cols: Vec, /// All equivalent lexicographical orderings that describe the schema. pub output_ordering: Vec, - /// Indicates whether this plan may produce an infinite stream of records. - pub infinite_source: bool, } impl FileScanConfig { @@ -707,7 +705,6 @@ mod tests { statistics, table_partition_cols, output_ordering: vec![], - infinite_source: false, } } diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index a715f6e8e3cd..99fb088b66f4 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -667,7 +667,6 @@ mod tests { limit: self.limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let metrics_set = ExecutionPlanMetricsSet::new(); let file_stream = FileStream::new(&config, 0, self.opener, &metrics_set) diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 9c3b523a652c..c74fd13e77aa 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -110,10 +110,6 @@ impl ExecutionPlan for NdJsonExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config.infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() @@ -462,7 +458,6 @@ mod tests { limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -541,7 +536,6 @@ mod tests { limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -589,7 +583,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -642,7 +635,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 8e4dd5400b20..9d1c373aee7c 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -133,10 +133,6 @@ impl DisplayAs for FileScanConfig { write!(f, ", limit={limit}")?; } - if self.infinite_source { - write!(f, ", infinite_source=true")?; - } - if let Some(ordering) = orderings.first() { if !ordering.is_empty() { let start = if orderings.len() == 1 { diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 2b10b05a273a..ade149da6991 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -882,7 +882,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, predicate, None, @@ -1539,7 +1538,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1654,7 +1652,6 @@ mod tests { ), ], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1718,7 +1715,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 58a4f08341d6..8916fa814a4a 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -964,14 +964,9 @@ impl SessionContext { sql_definition: Option, ) -> Result<()> { let table_path = ListingTableUrl::parse(table_path)?; - let resolved_schema = match (provided_schema, options.infinite_source) { - (Some(s), _) => s, - (None, false) => options.infer_schema(&self.state(), &table_path).await?, - (None, true) => { - return plan_err!( - "Schema inference for infinite data sources is not supported." - ) - } + let resolved_schema = match provided_schema { + Some(s) => s, + None => options.infer_schema(&self.state(), &table_path).await?, }; let config = ListingTableConfig::new(table_path) .with_listing_options(options) diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index c50ea36b68ec..7359a6463059 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -257,7 +257,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 099759741a10..0aef126578f3 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1775,7 +1775,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1803,7 +1802,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1825,7 +1823,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -1856,7 +1853,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -3957,7 +3953,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, false, b',', diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 277404b301c4..c0e9b834e66f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -2117,7 +2117,7 @@ mod tests { async fn test_with_lost_ordering_bounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2141,10 +2141,11 @@ mod tests { } #[tokio::test] + #[ignore] async fn test_with_lost_ordering_unbounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2171,10 +2172,12 @@ mod tests { } #[tokio::test] + #[ignore] async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + // Make source unbounded + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2203,7 +2206,7 @@ mod tests { async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); @@ -2224,7 +2227,7 @@ mod tests { async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec( @@ -2252,7 +2255,7 @@ mod tests { async fn test_window_multi_layer_requirement() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, vec![], false); + let source = csv_exec_sorted(&schema, vec![]); let sort = sort_exec(sort_exprs.clone(), source); let repartition = repartition_exec(sort); let repartition = spr_repartition_exec(repartition); diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 664afbe822ff..7e1312dad23e 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1541,7 +1541,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![vec![]], - infinite_source: false, }, false, 0, @@ -1568,7 +1567,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![vec![]], - infinite_source: false, }, false, 0, diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index af45df7d8474..41f2b39978a4 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -350,7 +350,7 @@ mod tests { async fn test_replace_multiple_input_repartition_1() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -362,15 +362,15 @@ mod tests { " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -378,7 +378,7 @@ mod tests { async fn test_with_inter_children_change_only() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -408,7 +408,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; let expected_optimized = [ @@ -419,9 +419,9 @@ mod tests { " SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -429,7 +429,7 @@ mod tests { async fn test_replace_multiple_input_repartition_2() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); @@ -444,16 +444,16 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -461,7 +461,7 @@ mod tests { async fn test_replace_multiple_input_repartition_with_extra_steps() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -478,7 +478,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -486,9 +486,9 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -496,7 +496,7 @@ mod tests { async fn test_replace_multiple_input_repartition_with_extra_steps_2() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); @@ -516,7 +516,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -525,9 +525,9 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -535,7 +535,7 @@ mod tests { async fn test_not_replacing_when_no_need_to_preserve_sorting() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -550,7 +550,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "CoalescePartitionsExec", @@ -558,7 +558,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -568,7 +568,7 @@ mod tests { async fn test_with_multiple_replacable_repartitions() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -587,7 +587,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -596,9 +596,9 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -606,7 +606,7 @@ mod tests { async fn test_not_replace_with_different_orderings() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( @@ -625,14 +625,14 @@ mod tests { " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -642,7 +642,7 @@ mod tests { async fn test_with_lost_ordering() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -654,15 +654,15 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -670,7 +670,7 @@ mod tests { async fn test_with_lost_and_kept_ordering() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -700,7 +700,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ @@ -712,9 +712,9 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -723,14 +723,14 @@ mod tests { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = csv_exec_sorted(&schema, left_sort_exprs, true); + let left_source = csv_exec_sorted(&schema, left_sort_exprs); let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = csv_exec_sorted(&schema, right_sort_exprs, true); + let right_source = csv_exec_sorted(&schema, right_sort_exprs); let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -756,11 +756,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ @@ -770,11 +770,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -784,7 +784,7 @@ mod tests { async fn test_with_bounded_input() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -931,7 +931,6 @@ mod tests { fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; @@ -949,7 +948,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, true, 0, diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 678dc1f373e3..6e14cca21fed 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -45,6 +45,7 @@ use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use crate::datasource::stream::{StreamConfig, StreamTable}; use async_trait::async_trait; async fn register_current_csv( @@ -54,14 +55,19 @@ async fn register_current_csv( ) -> Result<()> { let testdata = crate::test_util::arrow_test_data(); let schema = crate::test_util::aggr_test_schema(); - ctx.register_csv( - table_name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new() - .schema(&schema) - .mark_infinite(infinite), - ) - .await?; + let path = format!("{testdata}/csv/aggregate_test_100.csv"); + + match infinite { + true => { + let config = StreamConfig::new_file(schema, path.into()); + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; + } + false => { + ctx.register_csv(table_name, &path, CsvReadOptions::new().schema(&schema)) + .await?; + } + } + Ok(()) } @@ -272,7 +278,6 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -296,7 +301,6 @@ pub fn parquet_exec_sorted( limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source: false, }, None, None, diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index aad5c19044ea..8770c0c4238a 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -203,7 +203,6 @@ pub fn partitioned_csv_config( limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }) } @@ -277,7 +276,6 @@ fn make_decimal() -> RecordBatch { pub fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); @@ -291,7 +289,6 @@ pub fn csv_exec_sorted( limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, false, 0, diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index c6b43de0c18d..282b0f7079ee 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -36,7 +36,6 @@ use crate::datasource::provider::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; -use crate::execution::options::ReadOptions; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -58,6 +57,7 @@ use futures::Stream; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::datasource::stream::{StreamConfig, StreamTable}; pub use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq}; /// Scan an empty data source, mainly used in tests @@ -342,30 +342,17 @@ impl RecordBatchStream for UnboundedStream { } /// This function creates an unbounded sorted file for testing purposes. -pub async fn register_unbounded_file_with_ordering( +pub fn register_unbounded_file_with_ordering( ctx: &SessionContext, schema: SchemaRef, file_path: &Path, table_name: &str, file_sort_order: Vec>, - with_unbounded_execution: bool, ) -> Result<()> { - // Mark infinite and provide schema: - let fifo_options = CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(with_unbounded_execution); - // Get listing options: - let options_sort = fifo_options - .to_listing_options(&ctx.copied_config()) - .with_file_sort_order(file_sort_order); + let config = + StreamConfig::new_file(schema, file_path.into()).with_order(file_sort_order); + // Register table: - ctx.register_listing_table( - table_name, - file_path.as_os_str().to_str().unwrap(), - options_sort, - Some(schema), - None, - ) - .await?; + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index f3c0d2987a46..336a6804637a 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -156,7 +156,6 @@ impl TestParquetFile { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let df_schema = self.schema.clone().to_dfschema_ref()?; diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 3752d42dbf43..e76b201e0222 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -85,7 +85,6 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index e1e8b8e66edd..23a56bc821d4 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -81,7 +81,6 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, Some(predicate), None, diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 25c62f18f5ba..00f3eada496e 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -69,7 +69,6 @@ async fn multi_parquet_coercion() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -133,7 +132,6 @@ async fn multi_parquet_coercion_projection() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 528bde632355..d1f270b540b5 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion::datasource::stream::{StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; use super::*; @@ -105,9 +106,7 @@ async fn join_change_in_planner() -> Result<()> { &left_file_path, "left", file_sort_order.clone(), - true, - ) - .await?; + )?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone()).unwrap(); register_unbounded_file_with_ordering( @@ -116,9 +115,7 @@ async fn join_change_in_planner() -> Result<()> { &right_file_path, "right", file_sort_order, - true, - ) - .await?; + )?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -160,20 +157,13 @@ async fn join_change_in_planner_without_sort() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; + let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema, right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -217,20 +207,12 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema.clone(), right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index dcebfbf2dabb..5c0ef615cacd 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -526,7 +526,6 @@ pub fn parse_protobuf_file_scan_config( limit: proto.limit.as_ref().map(|sl| sl.limit as usize), table_partition_cols, output_ordering, - infinite_source: false, }) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 4a512413e73e..9a9827f2a090 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -492,7 +492,6 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let predicate = Arc::new(BinaryExpr::new( diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 942798173e0e..3098dc386e6a 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -112,7 +112,6 @@ pub async fn from_substrait_rel( limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; if let Some(MaskExpression { select, .. }) = &read.projection { diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index b64dd2c138fc..e5af3f94cc05 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -49,7 +49,6 @@ async fn parquet_exec() -> Result<()> { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let parquet_exec: Arc = Arc::new(ParquetExec::new(scan_config, None, None)); From d65b51a4d5fef13135b900249a4f7934b1098339 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Dec 2023 14:00:39 -0500 Subject: [PATCH 17/31] Update substrait requirement from 0.20.0 to 0.21.0 (#8574) Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version. - [Release notes](https://github.com/substrait-io/substrait-rs/releases) - [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.20.0...v0.21.0) --- updated-dependencies: - dependency-name: substrait dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/substrait/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 42ebe56c298b..0a9a6e8dd12b 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -35,7 +35,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.20.0" +substrait = "0.21.0" tokio = "1.17" [features] From ceead1cc48fd903bd877bb45e258b8ccc12e5b30 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 18 Dec 2023 22:01:50 +0300 Subject: [PATCH 18/31] [minor]: Fix rank calculation bug when empty order by is seen (#8567) * minor: fix to support scalars * Fix empty order by rank implementation --------- Co-authored-by: comphead --- datafusion/physical-expr/src/window/rank.rs | 13 ++++++-- .../physical-expr/src/window/window_expr.rs | 2 +- datafusion/sqllogictest/test_files/window.slt | 30 +++++++++++-------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 9bc36728f46e..86af5b322133 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -141,9 +141,16 @@ impl PartitionEvaluator for RankEvaluator { // There is no argument, values are order by column values (where rank is calculated) let range_columns = values; let last_rank_data = get_row_at_idx(range_columns, row_idx)?; - let empty = self.state.last_rank_data.is_empty(); - if empty || self.state.last_rank_data != last_rank_data { - self.state.last_rank_data = last_rank_data; + let new_rank_encountered = + if let Some(state_last_rank_data) = &self.state.last_rank_data { + // if rank data changes, new rank is encountered + state_last_rank_data != &last_rank_data + } else { + // First rank seen + true + }; + if new_rank_encountered { + self.state.last_rank_data = Some(last_rank_data); self.state.last_rank_boundary += self.state.current_group_count; self.state.current_group_count = 1; self.state.n_rank += 1; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 4211a616e100..548fae75bd97 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -274,7 +274,7 @@ pub enum WindowFn { #[derive(Debug, Clone, Default)] pub struct RankState { /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Vec, + pub last_rank_data: Option>, /// The index where last_rank_boundary is started pub last_rank_boundary: usize, /// Keep the number of entries in current rank diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 864f7dc0a47d..aa083290b4f4 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3801,19 +3801,25 @@ select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x 1 1 +# support scalar value in ORDER BY +query I +select dense_rank() over () rnk from (select 1 a union all select 2 a) x +---- +1 +1 + # support scalar value in both ORDER BY and PARTITION BY, RANK function -# TODO: fix the test, some issue in RANK -#query IIIIII -#select rank() over (partition by 1 order by 1) rnk, -# rank() over (partition by a, 1 order by 1) rnk1, -# rank() over (partition by a, 1 order by a, 1) rnk2, -# rank() over (partition by 1) rnk3, -# rank() over (partition by null) rnk4, -# rank() over (partition by 1, null, a) rnk5 -#from (select 1 a union all select 2 a) x -#---- -#1 1 1 1 1 1 -#1 1 1 1 1 1 +query IIIIII +select rank() over (partition by 1 order by 1) rnk, + rank() over (partition by a, 1 order by 1) rnk1, + rank() over (partition by a, 1 order by a, 1) rnk2, + rank() over (partition by 1) rnk3, + rank() over (partition by null) rnk4, + rank() over (partition by 1, null, a) rnk5 +from (select 1 a union all select 2 a) x +---- +1 1 1 1 1 1 +1 1 1 1 1 1 # support scalar value in both ORDER BY and PARTITION BY, ROW_NUMBER function query IIIIII From b5e94a688e3a66325cc6ed9b2e35b44cf6cd9ba8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Dec 2023 14:08:33 -0500 Subject: [PATCH 19/31] Add `LiteralGuarantee` on columns to extract conditions required for `PhysicalExpr` expressions to evaluate to true (#8437) * Introduce LiteralGurantee to find col=const * Improve comments * Improve documentation * Add more documentation and tests * refine documentation and tests * Apply suggestions from code review Co-authored-by: Nga Tran * Fix half comment * swap operators before analysis * More tests * cmt * Apply suggestions from code review Co-authored-by: Ruihang Xia * refine comments more --------- Co-authored-by: Nga Tran Co-authored-by: Ruihang Xia --- .../physical-expr/src/utils/guarantee.rs | 709 ++++++++++++++++++ .../src/{utils.rs => utils/mod.rs} | 33 +- 2 files changed, 729 insertions(+), 13 deletions(-) create mode 100644 datafusion/physical-expr/src/utils/guarantee.rs rename datafusion/physical-expr/src/{utils.rs => utils/mod.rs} (96%) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs new file mode 100644 index 000000000000..59ec255754c0 --- /dev/null +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -0,0 +1,709 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LiteralGuarantee`] predicate analysis to determine if a column is a +//! constant. + +use crate::utils::split_disjunction; +use crate::{split_conjunction, PhysicalExpr}; +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::Operator; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +/// Represents a guarantee that must be true for a boolean expression to +/// evaluate to `true`. +/// +/// The guarantee takes the form of a column and a set of literal (constant) +/// [`ScalarValue`]s. For the expression to evaluate to `true`, the column *must +/// satisfy* the guarantee(s). +/// +/// To satisfy the guarantee, depending on [`Guarantee`], the values in the +/// column must either: +/// +/// 1. be ONLY one of that set +/// 2. NOT be ANY of that set +/// +/// # Uses `LiteralGuarantee`s +/// +/// `LiteralGuarantee`s can be used to simplify filter expressions and skip data +/// files (e.g. row groups in parquet files) by proving expressions can not +/// possibly evaluate to `true`. For example, if we have a guarantee that `a` +/// must be in (`1`) for a filter to evaluate to `true`, then we can skip any +/// partition where we know that `a` never has the value of `1`. +/// +/// **Important**: If a `LiteralGuarantee` is not satisfied, the relevant +/// expression is *guaranteed* to evaluate to `false` or `null`. **However**, +/// the opposite does not hold. Even if all `LiteralGuarantee`s are satisfied, +/// that does **not** guarantee that the predicate will actually evaluate to +/// `true`: it may still evaluate to `true`, `false` or `null`. +/// +/// # Creating `LiteralGuarantee`s +/// +/// Use [`LiteralGuarantee::analyze`] to extract literal guarantees from a +/// filter predicate. +/// +/// # Details +/// A guarantee can be one of two forms: +/// +/// 1. The column must be one the values for the predicate to be `true`. If the +/// column takes on any other value, the predicate can not evaluate to `true`. +/// For example, +/// `(a = 1)`, `(a = 1 OR a = 2) or `a IN (1, 2, 3)` +/// +/// 2. The column must NOT be one of the values for the predicate to be `true`. +/// If the column can ONLY take one of these values, the predicate can not +/// evaluate to `true`. For example, +/// `(a != 1)`, `(a != 1 AND a != 2)` or `a NOT IN (1, 2, 3)` +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralGuarantee { + pub column: Column, + pub guarantee: Guarantee, + pub literals: HashSet, +} + +/// What is guaranteed about the values for a [`LiteralGuarantee`]? +#[derive(Debug, Clone, PartialEq)] +pub enum Guarantee { + /// Guarantee that the expression is `true` if `column` is one of the values. If + /// `column` is not one of the values, the expression can not be `true`. + In, + /// Guarantee that the expression is `true` if `column` is not ANY of the + /// values. If `column` only takes one of these values, the expression can + /// not be `true`. + NotIn, +} + +impl LiteralGuarantee { + /// Create a new instance of the guarantee if the provided operator is + /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to + /// create these structures from an predicate (boolean expression). + fn try_new<'a>( + column_name: impl Into, + op: Operator, + literals: impl IntoIterator, + ) -> Option { + let guarantee = match op { + Operator::Eq => Guarantee::In, + Operator::NotEq => Guarantee::NotIn, + _ => return None, + }; + + let literals: HashSet<_> = literals.into_iter().cloned().collect(); + + Some(Self { + column: Column::from_name(column_name), + guarantee, + literals, + }) + } + + /// Return a list of [`LiteralGuarantee`]s that must be satisfied for `expr` + /// to evaluate to `true`. + /// + /// If more than one `LiteralGuarantee` is returned, they must **all** hold + /// for the expression to possibly be `true`. If any is not satisfied, the + /// expression is guaranteed to be `null` or `false`. + /// + /// # Notes: + /// 1. `expr` must be a boolean expression. + /// 2. `expr` is not simplified prior to analysis. + pub fn analyze(expr: &Arc) -> Vec { + // split conjunction: AND AND ... + split_conjunction(expr) + .into_iter() + // for an `AND` conjunction to be true, all terms individually must be true + .fold(GuaranteeBuilder::new(), |builder, expr| { + if let Some(cel) = ColOpLit::try_new(expr) { + return builder.aggregate_conjunct(cel); + } else { + // split disjunction: OR OR ... + let disjunctions = split_disjunction(expr); + + // We are trying to add a guarantee that a column must be + // in/not in a particular set of values for the expression + // to evaluate to true. + // + // A disjunction is true, if at least one of the terms is be + // true. + // + // Thus, we can infer a guarantee if all terms are of the + // form `(col literal) OR (col literal) OR ...`. + // + // For example, we can infer that `a = 1 OR a = 2 OR a = 3` + // is guaranteed to be true ONLY if a is in (`1`, `2` or `3`). + // + // However, for something like `a = 1 OR a = 2 OR a < 0` we + // **can't** guarantee that the predicate is only true if a + // is in (`1`, `2`), as it could also be true if `a` were less + // than zero. + let terms = disjunctions + .iter() + .filter_map(|expr| ColOpLit::try_new(expr)) + .collect::>(); + + if terms.is_empty() { + return builder; + } + + // if not all terms are of the form (col literal), + // can't infer any guarantees + if terms.len() != disjunctions.len() { + return builder; + } + + // if all terms are 'col literal' with the same column + // and operation we can infer any guarantees + let first_term = &terms[0]; + if terms.iter().all(|term| { + term.col.name() == first_term.col.name() + && term.op == first_term.op + }) { + builder.aggregate_multi_conjunct( + first_term.col, + first_term.op, + terms.iter().map(|term| term.lit.value()), + ) + } else { + // can't infer anything + builder + } + } + }) + .build() + } +} + +/// Combines conjuncts (aka terms `AND`ed together) into [`LiteralGuarantee`]s, +/// preserving insert order +#[derive(Debug, Default)] +struct GuaranteeBuilder<'a> { + /// List of guarantees that have been created so far + /// if we have determined a subsequent conjunct invalidates a guarantee + /// e.g. `a = foo AND a = bar` then the relevant guarantee will be None + guarantees: Vec>, + + /// Key is the (column name, operator type) + /// Value is the index into `guarantees` + map: HashMap<(&'a crate::expressions::Column, Operator), usize>, +} + +impl<'a> GuaranteeBuilder<'a> { + fn new() -> Self { + Default::default() + } + + /// Aggregate a new single `AND col literal` term to this builder + /// combining with existing guarantees if possible. + /// + /// # Examples + /// * `AND (a = 1)`: `a` is guaranteed to be 1 + /// * `AND (a != 1)`: a is guaranteed to not be 1 + fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self { + self.aggregate_multi_conjunct( + col_op_lit.col, + col_op_lit.op, + [col_op_lit.lit.value()], + ) + } + + /// Aggregates a new single column, multi literal term to ths builder + /// combining with previously known guarantees if possible. + /// + /// # Examples + /// For the following examples, we can guarantee the expression is `true` if: + /// * `AND (a = 1 OR a = 2 OR a = 3)`: a is in (1, 2, or 3) + /// * `AND (a IN (1,2,3))`: a is in (1, 2, or 3) + /// * `AND (a != 1 OR a != 2 OR a != 3)`: a is not in (1, 2, or 3) + /// * `AND (a NOT IN (1,2,3))`: a is not in (1, 2, or 3) + fn aggregate_multi_conjunct( + mut self, + col: &'a crate::expressions::Column, + op: Operator, + new_values: impl IntoIterator, + ) -> Self { + let key = (col, op); + if let Some(index) = self.map.get(&key) { + // already have a guarantee for this column + let entry = &mut self.guarantees[*index]; + + let Some(existing) = entry else { + // determined the previous guarantee for this column has been + // invalidated, nothing to do + return self; + }; + + // Combine conjuncts if we have `a != foo AND a != bar`. `a = foo + // AND a = bar` doesn't make logical sense so we don't optimize this + // case + match existing.guarantee { + // knew that the column could not be a set of values + // + // For example, if we previously had `a != 5` and now we see + // another `AND a != 6` we know that a must not be either 5 or 6 + // for the expression to be true + Guarantee::NotIn => { + // can extend if only single literal, otherwise invalidate + let new_values: HashSet<_> = new_values.into_iter().collect(); + if new_values.len() == 1 { + existing.literals.extend(new_values.into_iter().cloned()) + } else { + // this is like (a != foo AND (a != bar OR a != baz)). + // We can't combine the (a != bar OR a != baz) part, but + // it also doesn't invalidate our knowledge that a != + // foo is required for the expression to be true + } + } + Guarantee::In => { + // for an IN guarantee, it is ok if the value is the same + // e.g. `a = foo AND a = foo` but not if the value is different + // e.g. `a = foo AND a = bar` + if new_values + .into_iter() + .all(|new_value| existing.literals.contains(new_value)) + { + // all values are already in the set + } else { + // at least one was not, so invalidate the guarantee + *entry = None; + } + } + } + } else { + // This is a new guarantee + let new_values: HashSet<_> = new_values.into_iter().collect(); + + // new_values are combined with OR, so we can only create a + // multi-column guarantee for `=` (or a single value). + // (e.g. ignore `a != foo OR a != bar`) + if op == Operator::Eq || new_values.len() == 1 { + if let Some(guarantee) = + LiteralGuarantee::try_new(col.name(), op, new_values) + { + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); + } + } + } + + self + } + + /// Return all guarantees that have been created so far + fn build(self) -> Vec { + // filter out any guarantees that have been invalidated + self.guarantees.into_iter().flatten().collect() + } +} + +/// Represents a single `col literal` expression +struct ColOpLit<'a> { + col: &'a crate::expressions::Column, + op: Operator, + lit: &'a crate::expressions::Literal, +} + +impl<'a> ColOpLit<'a> { + /// Returns Some(ColEqLit) if the expression is either: + /// 1. `col literal` + /// 2. `literal col` + /// + /// Returns None otherwise + fn try_new(expr: &'a Arc) -> Option { + let binary_expr = expr + .as_any() + .downcast_ref::()?; + + let (left, op, right) = ( + binary_expr.left().as_any(), + binary_expr.op(), + binary_expr.right().as_any(), + ); + + // col literal + if let (Some(col), Some(lit)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { col, op: *op, lit }) + } + // literal col + else if let (Some(lit), Some(col)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + // Used swapped operator operator, if possible + op.swap().map(|op| Self { col, op, lit }) + } else { + None + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::create_physical_expr; + use crate::execution_props::ExecutionProps; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ToDFSchema; + use datafusion_expr::expr_fn::*; + use datafusion_expr::{lit, Expr}; + use std::sync::OnceLock; + + #[test] + fn test_literal() { + // a single literal offers no guarantee + test_analyze(lit(true), vec![]) + } + + #[test] + fn test_single() { + // a = "foo" + test_analyze(col("a").eq(lit("foo")), vec![in_guarantee("a", ["foo"])]); + // "foo" = a + test_analyze(lit("foo").eq(col("a")), vec![in_guarantee("a", ["foo"])]); + // a != "foo" + test_analyze( + col("a").not_eq(lit("foo")), + vec![not_in_guarantee("a", ["foo"])], + ); + // "foo" != a + test_analyze( + lit("foo").not_eq(col("a")), + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_single_column() { + // b = 1 AND b = 2. This is impossible. Ideally this expression could be simplified to false + test_analyze(col("b").eq(lit(1)).and(col("b").eq(lit(2))), vec![]); + // b = 1 AND b != 2 . In theory, this could be simplified to `b = 1`. + test_analyze( + col("b").eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![ + // can only be true of b is 1 and b is not 2 (even though it is redundant) + in_guarantee("b", [1]), + not_in_guarantee("b", [2]), + ], + ); + // b != 1 AND b = 2. In theory, this could be simplified to `b = 2`. + test_analyze( + col("b").not_eq(lit(1)).and(col("b").eq(lit(2))), + vec![ + // can only be true of b is not 1 and b is is 2 (even though it is redundant) + not_in_guarantee("b", [1]), + in_guarantee("b", [2]), + ], + ); + // b != 1 AND b != 2 + test_analyze( + col("b").not_eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2])], + ); + // b != 1 AND b != 2 and b != 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + // b != 1 AND b = 2 and b != 3. Can only be true if b is 2 and b is not in (1, 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 3]), in_guarantee("b", [2])], + ); + // b != 1 AND b != 2 and b = 3 (in theory could determine b = 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").eq(lit(3))), + vec![not_in_guarantee("b", [1, 2]), in_guarantee("b", [3])], + ); + // b != 1 AND b != 2 and b > 3 (to be true, b can't be either 1 or 2 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").gt(lit(3))), + vec![not_in_guarantee("b", [1, 2])], + ); + } + + #[test] + fn test_conjunction_multi_column() { + // a = "foo" AND b = 1 + test_analyze( + col("a").eq(lit("foo")).and(col("b").eq(lit(1))), + vec![ + // should find both column guarantees + in_guarantee("a", ["foo"]), + in_guarantee("b", [1]), + ], + ); + // a != "foo" AND b != 1 + test_analyze( + col("a").not_eq(lit("foo")).and(col("b").not_eq(lit(1))), + // should find both column guarantees + vec![not_in_guarantee("a", ["foo"]), not_in_guarantee("b", [1])], + ); + // a = "foo" AND a = "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))), + // this predicate is impossible ( can't be both foo and bar), + vec![], + ); + // a = "foo" AND b != "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + vec![in_guarantee("a", ["foo"]), not_in_guarantee("a", ["bar"])], + ); + // a != "foo" AND a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + // know it isn't "foo" or "bar" + vec![not_in_guarantee("a", ["foo", "bar"])], + ); + // a != "foo" AND a != "bar" and a != "baz" + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar"))) + .and(col("a").not_eq(lit("baz"))), + // know it isn't "foo" or "bar" or "baz" + vec![not_in_guarantee("a", ["foo", "bar", "baz"])], + ); + // a = "foo" AND a = "foo" + let expr = col("a").eq(lit("foo")); + test_analyze(expr.clone().and(expr), vec![in_guarantee("a", ["foo"])]); + // b > 5 AND b = 10 (should get an b = 10 guarantee) + test_analyze( + col("b").gt(lit(5)).and(col("b").eq(lit(10))), + vec![in_guarantee("b", [10])], + ); + // b > 10 AND b = 10 (this is impossible) + test_analyze( + col("b").gt(lit(10)).and(col("b").eq(lit(10))), + vec![ + // if b isn't 10, it can not be true (though the expression actually can never be true) + in_guarantee("b", [10]), + ], + ); + // a != "foo" and (a != "bar" OR a != "baz") + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar")).or(col("a").not_eq(lit("baz")))), + // a is not foo (we can't represent other knowledge about a) + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_and_disjunction_single_column() { + // b != 1 AND (b > 2) + test_analyze( + col("b").not_eq(lit(1)).and(col("b").gt(lit(2))), + vec![ + // for the expression to be true, b can not be one + not_in_guarantee("b", [1]), + ], + ); + + // b = 1 AND (b = 2 OR b = 3). Could be simplified to false. + test_analyze( + col("b") + .eq(lit(1)) + .and(col("b").eq(lit(2)).or(col("b").eq(lit(3)))), + vec![ + // in theory, b must be 1 and one of 2,3 for this expression to be true + // which is a logical contradiction + ], + ); + } + + #[test] + fn test_disjunction_single_column() { + // b = 1 OR b = 2 + test_analyze( + col("b").eq(lit(1)).or(col("b").eq(lit(2))), + vec![in_guarantee("b", [1, 2])], + ); + // b != 1 OR b = 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").eq(lit(2))), vec![]); + // b = 1 OR b != 2 + test_analyze(col("b").eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 OR b = 3 -- in theory could guarantee that b = 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .or(col("b").not_eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + // b = 1 OR b = 2 OR b = 3 + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(col("b").eq(lit(3))), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b = 1 OR b = 2 OR b > 3 -- can't guarantee that the expression is only true if a is in (1, 2) + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + } + + #[test] + fn test_disjunction_multi_column() { + // a = "foo" OR b = 1 + test_analyze( + col("a").eq(lit("foo")).or(col("b").eq(lit(1))), + // no can't have a single column guarantee (if a = "foo" then b != 1) etc + vec![], + ); + // a != "foo" OR b != 1 + test_analyze( + col("a").not_eq(lit("foo")).or(col("b").not_eq(lit(1))), + // No single column guarantee + vec![], + ); + // a = "foo" OR a = "bar" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))), + vec![in_guarantee("a", ["foo", "bar"])], + ); + // a = "foo" OR a = "foo" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("foo"))), + vec![in_guarantee("a", ["foo"])], + ); + // a != "foo" OR a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).or(col("a").not_eq(lit("bar"))), + // can't represent knowledge about a in this case + vec![], + ); + // a = "foo" OR a = "bar" OR a = "baz" + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("a").eq(lit("baz"))), + vec![in_guarantee("a", ["foo", "bar", "baz"])], + ); + // (a = "foo" OR a = "bar") AND (a = "baz)" + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("a").eq(lit("baz"))), + // this could potentially be represented as 2 constraints with a more + // sophisticated analysis + vec![], + ); + // (a = "foo" OR a = "bar") AND (b = 1) + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("b").eq(lit(1))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1])], + ); + // (a = "foo" OR a = "bar") OR (b = 1) + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("b").eq(lit(1))), + // can't represent knowledge about a or b in this case + vec![], + ); + } + + // TODO https://github.com/apache/arrow-datafusion/issues/8436 + // a IN (...) + // b NOT IN (...) + + /// Tests that analyzing expr results in the expected guarantees + fn test_analyze(expr: Expr, expected: Vec) { + println!("Begin analyze of {expr}"); + let schema = schema(); + let physical_expr = logical2physical(&expr, &schema); + + let actual = LiteralGuarantee::analyze(&physical_expr); + assert_eq!( + expected, actual, + "expr: {expr}\ + \n\nexpected: {expected:#?}\ + \n\nactual: {actual:#?}\ + \n\nexpr: {expr:#?}\ + \n\nphysical_expr: {physical_expr:#?}" + ); + } + + /// Guarantee that the expression is true if the column is one of the specified values + fn in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Operator::Eq, literals.iter()).unwrap() + } + + /// Guarantee that the expression is true if the column is NOT any of the specified values + fn not_in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Operator::NotEq, literals.iter()).unwrap() + } + + /// Convert a logical expression to a physical expression (without any simplification, etc) + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } + + // Schema for testing + fn schema() -> SchemaRef { + SCHEMA + .get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + }) + .clone() + } + + static SCHEMA: OnceLock = OnceLock::new(); +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils/mod.rs similarity index 96% rename from datafusion/physical-expr/src/utils.rs rename to datafusion/physical-expr/src/utils/mod.rs index 71a7ff5fb778..87ef36558b96 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +mod guarantee; +pub use guarantee::{Guarantee, LiteralGuarantee}; + use std::borrow::Borrow; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -41,25 +44,29 @@ use petgraph::stable_graph::StableGraph; pub fn split_conjunction( predicate: &Arc, ) -> Vec<&Arc> { - split_conjunction_impl(predicate, vec![]) + split_impl(Operator::And, predicate, vec![]) } -fn split_conjunction_impl<'a>( +/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs. +/// +/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] +pub fn split_disjunction( + predicate: &Arc, +) -> Vec<&Arc> { + split_impl(Operator::Or, predicate, vec![]) +} + +fn split_impl<'a>( + operator: Operator, predicate: &'a Arc, mut exprs: Vec<&'a Arc>, ) -> Vec<&'a Arc> { match predicate.as_any().downcast_ref::() { - Some(binary) => match binary.op() { - Operator::And => { - let exprs = split_conjunction_impl(binary.left(), exprs); - split_conjunction_impl(binary.right(), exprs) - } - _ => { - exprs.push(predicate); - exprs - } - }, - None => { + Some(binary) if binary.op() == &operator => { + let exprs = split_impl(operator, binary.left(), exprs); + split_impl(operator, binary.right(), exprs) + } + Some(_) | None => { exprs.push(predicate); exprs } From 65b997bc465fe6b9dc6692deebbd2d72da189702 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 18 Dec 2023 22:11:56 +0300 Subject: [PATCH 20/31] [MINOR]: Parametrize sort-preservation tests to exercise all situations (unbounded/bounded sources and flag behavior) (#8575) * Re-introduce unbounded tests with new executor * Remove unnecessary test --- .../src/physical_optimizer/enforce_sorting.rs | 19 +- .../replace_with_order_preserving_variants.rs | 275 +++++++++++------- datafusion/core/src/test/mod.rs | 36 +++ 3 files changed, 208 insertions(+), 122 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index c0e9b834e66f..2b650a42696b 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -769,7 +769,7 @@ mod tests { use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::csv_exec_sorted; + use crate::test::{csv_exec_sorted, stream_exec_ordered}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -2141,11 +2141,11 @@ mod tests { } #[tokio::test] - #[ignore] async fn test_with_lost_ordering_unbounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + // create an unbounded source + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2159,25 +2159,24 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] - #[ignore] async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - // Make source unbounded - let source = csv_exec_sorted(&schema, sort_exprs); + // create an unbounded source + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2190,13 +2189,13 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 41f2b39978a4..671891be433c 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -276,9 +276,6 @@ pub(crate) fn replace_with_order_preserving_variants( mod tests { use super::*; - use crate::datasource::file_format::file_compression_type::FileCompressionType; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::filter::FilterExec; @@ -289,14 +286,16 @@ mod tests { use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::SessionConfig; + use crate::test::TestStreamPartition; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::TreeNode; - use datafusion_common::{Result, Statistics}; - use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_common::Result; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::streaming::StreamingTableExec; + use rstest::rstest; /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan /// against the original and expected plans. @@ -345,12 +344,15 @@ mod tests { }; } + #[rstest] #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected - async fn test_replace_multiple_input_repartition_1() -> Result<()> { + async fn test_replace_multiple_input_repartition_1( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -362,23 +364,31 @@ mod tests { " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_inter_children_change_only() -> Result<()> { + async fn test_with_inter_children_change_only( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -408,7 +418,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; let expected_optimized = [ @@ -419,17 +429,25 @@ mod tests { " SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_2( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); @@ -444,24 +462,32 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -478,7 +504,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -486,17 +512,25 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps_2( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); @@ -516,7 +550,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -525,17 +559,25 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replacing_when_no_need_to_preserve_sorting() -> Result<()> { + async fn test_not_replacing_when_no_need_to_preserve_sorting( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -550,7 +592,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "CoalescePartitionsExec", @@ -558,17 +600,25 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_replacable_repartitions() -> Result<()> { + async fn test_with_multiple_replacable_repartitions( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -587,7 +637,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -596,17 +646,25 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replace_with_different_orderings() -> Result<()> { + async fn test_not_replace_with_different_orderings( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( @@ -625,24 +683,32 @@ mod tests { " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering() -> Result<()> { + async fn test_with_lost_ordering( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -654,23 +720,31 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_and_kept_ordering() -> Result<()> { + async fn test_with_lost_and_kept_ordering( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -700,7 +774,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ @@ -712,25 +786,33 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_child_trees() -> Result<()> { + async fn test_with_multiple_child_trees( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = csv_exec_sorted(&schema, left_sort_exprs); + let left_source = stream_exec_ordered(&schema, left_sort_exprs); let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = csv_exec_sorted(&schema, right_sort_exprs); + let right_source = stream_exec_ordered(&schema, right_sort_exprs); let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -756,11 +838,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ @@ -770,41 +852,18 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); - Ok(()) - } - - #[tokio::test] - async fn test_with_bounded_input() -> Result<()> { - let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); - let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - let expected_input = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", - ]; - let expected_optimized = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } @@ -928,32 +987,24 @@ mod tests { // creates a csv exec source for the test purposes // projection and has_header parameters are given static due to testing needs - fn csv_exec_sorted( + fn stream_exec_ordered( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new( - "file_path".to_string(), - 100, - )]], - statistics: Statistics::new_unknown(schema), - projection: Some(projection), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - }, - true, - 0, - b'"', - None, - FileCompressionType::UNCOMPRESSED, - )) + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + Some(&projection), + vec![sort_exprs], + true, + ) + .unwrap(), + ) } } diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 8770c0c4238a..7a63466a3906 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -49,6 +49,7 @@ use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; #[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] @@ -298,6 +299,41 @@ pub fn csv_exec_sorted( )) } +// construct a stream partition for test purposes +pub(crate) struct TestStreamPartition { + pub schema: SchemaRef, +} + +impl PartitionStream for TestStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } +} + +/// Create an unbounded stream exec +pub fn stream_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + None, + vec![sort_exprs], + true, + ) + .unwrap(), + ) +} + /// A mock execution plan that simply returns the provided statistics #[derive(Debug, Clone)] pub struct StatisticsExec { From fc46b36a4078a7fdababfc2d3735e83caf1326f7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Dec 2023 14:27:32 -0500 Subject: [PATCH 21/31] Minor: Add some comments to scalar_udf example (#8576) * refine example * clippy --- datafusion-examples/examples/simple_udf.rs | 30 +++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index dba4385b8eea..591991786515 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -29,23 +29,23 @@ use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use datafusion_common::cast::as_float64_array; use std::sync::Arc; -// create local execution context with an in-memory table +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float64, false), - ])); - // define data. - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])), - Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), - ], - )?; + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; // declare a new context. In spark API, this corresponds to a new spark SQLsession let ctx = SessionContext::new(); From 1935c58f5cffe123839bae4e9d77a128351728e1 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 19 Dec 2023 04:17:47 +0800 Subject: [PATCH 22/31] Move Coercion for MakeArray to `coerce_arguments_for_signature` and introduce another one for ArrayAppend (#8317) * Signature for array_append and make_array Signed-off-by: jayzhan211 * combine variadicequal and coerced to equal Signed-off-by: jayzhan211 * follow postgres style on array_append(null, T) Signed-off-by: jayzhan211 * update comment for ArrayAndElement Signed-off-by: jayzhan211 * remove test Signed-off-by: jayzhan211 * add more test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion/common/src/utils.rs | 38 ++++++++++ datafusion/expr/src/built_in_function.rs | 13 ++-- datafusion/expr/src/signature.rs | 20 ++++- .../expr/src/type_coercion/functions.rs | 74 +++++++++++++++++-- .../optimizer/src/analyzer/type_coercion.rs | 34 --------- .../physical-expr/src/array_expressions.rs | 25 ++----- datafusion/sqllogictest/test_files/array.slt | 51 +++++++++---- 7 files changed, 176 insertions(+), 79 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index fecab8835e50..2d38ca21829b 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -342,6 +342,8 @@ pub fn longest_consecutive_prefix>( count } +/// Array Utils + /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` pub fn array_into_list_array(arr: ArrayRef) -> ListArray { @@ -429,6 +431,42 @@ pub fn base_type(data_type: &DataType) -> DataType { } } +/// A helper function to coerce base type in List. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::coerced_type_with_base_type_only; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let base_type = DataType::Float64; +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +pub fn coerced_type_with_base_type_only( + data_type: &DataType, + base_type: &DataType, +) -> DataType { + match data_type { + DataType::List(field) => { + let data_type = match field.data_type() { + DataType::List(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::List(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + + _ => base_type.clone(), + } +} + /// Compute the number of dimensions in a list data type. pub fn list_ndims(data_type: &DataType) -> u64 { if let DataType::List(field) = data_type { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fd899289ac82..289704ed98f8 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -915,10 +915,17 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArraySort => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayAppend => Signature { + type_signature: ArrayAndElement, + volatility: self.volatility(), + }, + BuiltinScalarFunction::MakeArray => { + // 0 or more arguments of arbitrary type + Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility()) + } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -958,10 +965,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), - BuiltinScalarFunction::MakeArray => { - // 0 or more arguments of arbitrary type - Signature::one_of(vec![VariadicAny, Any(0)], self.volatility()) - } BuiltinScalarFunction::Range => Signature::one_of( vec![ Exact(vec![Int64]), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 685601523f9b..3f07c300e196 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -91,11 +91,14 @@ pub enum TypeSignature { /// DataFusion attempts to coerce all argument types to match the first argument's type /// /// # Examples - /// A function such as `array` is `VariadicEqual` + /// Given types in signature should be coericible to the same final type. + /// A function such as `make_array` is `VariadicEqual`. + /// + /// `make_array(i32, i64) -> make_array(i64, i64)` VariadicEqual, /// One or more arguments with arbitrary types VariadicAny, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types. + /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. /// /// # Examples /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` @@ -113,6 +116,12 @@ pub enum TypeSignature { /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), + /// Specialized Signature for ArrayAppend and similar functions + /// The first argument should be List/LargeList, and the second argument should be non-list or list. + /// The second argument's list dimension should be one dimension less than the first argument's list dimension. + /// List dimension of the List/LargeList is equivalent to the number of List. + /// List dimension of the non-list is 0. + ArrayAndElement, } impl TypeSignature { @@ -136,11 +145,16 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicEqual => { + vec!["CoercibleT, .., CoercibleT".to_string()] + } TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } + TypeSignature::ArrayAndElement => { + vec!["ArrayAndElement(List, T)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 79b574238495..f95a30e025b4 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,7 +21,10 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::utils::list_ndims; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; + +use super::binary::comparison_coercion; /// Performs type coercion for function arguments. /// @@ -86,16 +89,66 @@ fn get_valid_types( .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::VariadicEqual => { - // one entry with the same len as current_types, whose type is `current_types[0]`. - vec![current_types - .iter() - .map(|_| current_types[0].clone()) - .collect()] + let new_type = current_types.iter().skip(1).try_fold( + current_types.first().unwrap().clone(), + |acc, x| { + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + ); + + match new_type { + Ok(new_type) => vec![vec![new_type; current_types.len()]], + Err(e) => return Err(e), + } } TypeSignature::VariadicAny => { vec![current_types.to_vec()] } + TypeSignature::Exact(valid_types) => vec![valid_types.clone()], + TypeSignature::ArrayAndElement => { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let array_type = ¤t_types[0]; + let elem_type = ¤t_types[1]; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. + if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + if let DataType::List(ref field) = array_type { + let elem_type = field.data_type(); + return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); + } else { + return Ok(vec![vec![]]); + } + } TypeSignature::Any(number) => { if current_types.len() != *number { return plan_err!( @@ -241,6 +294,15 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), + // Only accept list with the same number of dimensions unless the type is Null. + // List with different dimensions should be handled in TypeSignature or other places before this. + List(_) + if datafusion_common::utils::base_type(type_from).eq(&Null) + || list_ndims(type_from) == list_ndims(type_into) => + { + Some(type_into.clone()) + } + Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 91611251d9dd..c5e1180b9f97 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -590,26 +590,6 @@ fn coerce_arguments_for_fun( .collect::>>()?; } - if *fun == BuiltinScalarFunction::MakeArray { - // Find the final data type for the function arguments - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let new_type = current_types - .iter() - .skip(1) - .fold(current_types.first().unwrap().clone(), |acc, x| { - comparison_coercion(&acc, x).unwrap_or(acc) - }); - - return expressions - .iter() - .zip(current_types) - .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) - .collect(); - } Ok(expressions) } @@ -618,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result expr.clone().cast_to(to_type, schema) } -/// Cast array `expr` to the specified type, if possible -fn cast_array_expr( - expr: &Expr, - from_type: &DataType, - to_type: &DataType, - schema: &DFSchema, -) -> Result { - if from_type.equals_datatype(&DataType::Null) { - Ok(expr.clone()) - } else { - cast_expr(expr, to_type, schema) - } -} - /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7ccf58af832d..98c9aee8940f 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -361,7 +361,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: DataType::Null => { - let array = new_null_array(&DataType::Null, arrays.len()); + let array = + new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); Ok(Arc::new(array_into_list_array(array))) } DataType::LargeList(..) => array_array::(arrays, data_type), @@ -827,10 +828,14 @@ pub fn array_append(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; let element_array = &args[1]; - check_datatypes("array_append", &[list_array.values(), element_array])?; let res = match list_array.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element_array.to_owned()]), + DataType::Null => { + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); + } data_type => { return general_append_and_prepend( list_array, @@ -2284,18 +2289,4 @@ mod tests { expected_dim ); } - - #[test] - fn test_check_invalid_datatypes() { - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef; - - let args = [list_array.clone(), int64_array.clone()]; - - let array = array_append(&args); - - assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'."); - } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 210739aa51da..640f5064eae6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -297,10 +297,8 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; -query ? +query error select [1, true, null] ----- -[1, 1, ] query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() SELECT [now()] @@ -1253,18 +1251,43 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3 ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) -# TODO: array_append with NULLs -# array_append scalar function #1 -# query ? -# select array_append(make_array(), 4); -# ---- -# [4] +# array_append with NULLs -# array_append scalar function #2 -# query ?? -# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); -# ---- -# [[]] [[4]] +query error +select array_append(null, 1); + +query error +select array_append(null, [2, 3]); + +query error +select array_append(null, [[4]]); + +query ???? +select + array_append(make_array(), 4), + array_append(make_array(), null), + array_append(make_array(1, null, 3), 4), + array_append(make_array(null, null), 1) +; +---- +[4] [] [1, , 3, 4] [, , 1] + +# test invalid (non-null) +query error +select array_append(1, 2); + +query error +select array_append(1, [2]); + +query error +select array_append([1], [2]); + +query ?? +select + array_append(make_array(make_array(1, null, 3)), make_array(null)), + array_append(make_array(make_array(1, null, 3)), null); +---- +[[1, , 3], []] [[1, , 3], ] # array_append scalar function #3 query ??? From d220bf47f944dd019d6b1e5b2741535a3f90204f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 18 Dec 2023 21:22:34 +0100 Subject: [PATCH 23/31] support LargeList in array_positions (#8571) --- .../physical-expr/src/array_expressions.rs | 19 ++++++-- datafusion/sqllogictest/test_files/array.slt | 43 +++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 98c9aee8940f..cc4b2899fcb1 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1289,12 +1289,23 @@ fn general_position( /// Array_positions SQL function pub fn array_positions(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; let element = &args[1]; - check_datatypes("array_positions", &[arr.values(), element])?; - - general_positions::(arr, element) + match &args[0].data_type() { + DataType::List(_) => { + let arr = as_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + DataType::LargeList(_) => { + let arr = as_large_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + array_type => { + not_impl_err!("array_positions does not support type '{array_type:?}'.") + } + } } fn general_positions( diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 640f5064eae6..d148f7118176 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1832,18 +1832,33 @@ select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3 ---- [3, 4] [5] [1, 2, 3] +query ??? +select array_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions scalar function #2 (element is list) query ? select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]); ---- [2, 4] +query ? +select array_positions(arrow_cast(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), 'LargeList(List(Int64))'), [2, 1, 3]); +---- +[2, 4] + # list_positions scalar function #3 (function alias `array_positions`) query ??? select list_positions(['h', 'e', 'l', 'l', 'o'], 'l'), list_positions([1, 2, 3, 4, 5], 5), list_positions([1, 1, 1], 1); ---- [3, 4] [5] [1, 2, 3] +query ??? +select list_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions with columns #1 query ? select array_positions(column1, column2) from arrays_values_without_nulls; @@ -1853,6 +1868,14 @@ select array_positions(column1, column2) from arrays_values_without_nulls; [3] [4] +query ? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), column2) from arrays_values_without_nulls; +---- +[1] +[2] +[3] +[4] + # array_positions with columns #2 (element is list) query ? select array_positions(column1, column2) from nested_arrays; @@ -1860,6 +1883,12 @@ select array_positions(column1, column2) from nested_arrays; [3] [2, 5] +query ? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[3] +[2, 5] + # array_positions with columns and scalars #1 query ?? select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; @@ -1869,6 +1898,14 @@ select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], [] [3] [] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; +---- +[4] [1] +[] [] +[] [3] +[] [] + # array_positions with columns and scalars #2 (element is list) query ?? select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; @@ -1876,6 +1913,12 @@ select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array [6] [] [1] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(4, 5, 6)), array_positions(arrow_cast(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[6] [] +[1] [] + ## array_replace (aliases: `list_replace`) # array_replace scalar function #1 From d33ca4dd37b8b47120579b7c3e0456c1fcbcb06f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 18 Dec 2023 21:26:02 +0100 Subject: [PATCH 24/31] support LargeList in array_element (#8570) --- datafusion/expr/src/built_in_function.rs | 3 +- .../physical-expr/src/array_expressions.rs | 82 +++++++++++++------ datafusion/sqllogictest/test_files/array.slt | 72 +++++++++++++++- 3 files changed, 130 insertions(+), 27 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 289704ed98f8..3818e8ee5658 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -591,8 +591,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) => Ok(field.data_type().clone()), + LargeList(field) => Ok(field.data_type().clone()), _ => plan_err!( - "The {self} function can only accept list as the first argument" + "The {self} function can only accept list or largelist as the first argument" ), }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index cc4b2899fcb1..d39658108337 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -370,18 +370,14 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { } } -/// array_element SQL function -/// -/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. -/// `array_element(array, index)` -/// -/// For example: -/// > array_element(\[1, 2, 3], 2) -> 2 -pub fn array_element(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let indexes = as_int64_array(&args[1])?; - - let values = list_array.values(); +fn general_array_element( + array: &GenericListArray, + indexes: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -389,37 +385,47 @@ pub fn array_element(args: &[ArrayRef]) -> Result { let mut mutable = MutableArrayData::with_capacities(vec![&original_data], true, capacity); - fn adjusted_array_index(index: i64, len: usize) -> Option { + fn adjusted_array_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + let index: O = index.try_into().map_err(|_| { + DataFusionError::Execution(format!( + "array_element got invalid index: {}", + index + )) + })?; // 0 ~ len - 1 - let adjusted_zero_index = if index < 0 { - index + len as i64 + let adjusted_zero_index = if index < O::usize_as(0) { + index + len } else { - index - 1 + index - O::usize_as(1) }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; let len = end - start; // array is null - if len == 0 { + if len == O::usize_as(0) { mutable.extend_nulls(1); continue; } - let index = adjusted_array_index(indexes.value(row_index), len); + let index = adjusted_array_index::(indexes.value(row_index), len)?; if let Some(index) = index { - mutable.extend(0, start + index as usize, start + index as usize + 1); + let start = start.as_usize() + index.as_usize(); + mutable.extend(0, start, start + 1_usize); } else { // Index out of bounds mutable.extend_nulls(1); @@ -430,6 +436,32 @@ pub fn array_element(args: &[ArrayRef]) -> Result { Ok(arrow_array::make_array(data)) } +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. +/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + match &args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + _ => not_impl_err!( + "array_element does not support type: {:?}", + args[0].data_type() + ), + } +} + fn general_except( l: &GenericListArray, r: &GenericListArray, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index d148f7118176..b38f73ecb8db 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -717,7 +717,7 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # array_element error -query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument +query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument select array_element(1, 2); @@ -727,58 +727,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h' ---- 2 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element scalar function #2 (with positive index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11); +---- +NULL NULL + # array_element scalar function #3 (with zero) query IT select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0); +---- +NULL NULL + # array_element scalar function #4 (with NULL) query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + # array_element scalar function #5 (with negative index) query IT select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); ---- 4 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3); +---- +4 l + # array_element scalar function #6 (with negative index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7); +---- +NULL NULL + # array_element scalar function #7 (nested array) query ? select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); ---- [1, 2, 3, 4, 5] +query ? +select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1); +---- +[1, 2, 3, 4, 5] + # array_extract scalar function #8 (function alias `array_slice`) query IT select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_element scalar function #9 (function alias `array_slice`) query IT select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_extract scalar function #10 (function alias `array_slice`) query IT select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element with columns query I select array_element(column1, column2) from slices; @@ -791,6 +839,17 @@ NULL NULL 55 +query I +select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + # array_element with columns and scalars query II select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; @@ -803,6 +862,17 @@ NULL 23 NULL 43 5 NULL +query II +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + ## array_pop_back (aliases: `list_pop_back`) # array_pop_back scalar function #1 From 9bc61b31ae4f67c55c03214c9b807079e4fe0f44 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:22:33 +0300 Subject: [PATCH 25/31] Increase test coverage for unbounded and bounded cases (#8581) * Re-introduce unbounded tests with new executor * Remove unnecessary test * Enhance test coverage * Review * Test passes * Change argument order * Parametrize enforce sorting test * Imports --------- Co-authored-by: Mehmet Ozan Kabak --- .../src/physical_optimizer/enforce_sorting.rs | 92 ++- .../replace_with_order_preserving_variants.rs | 714 +++++++++++++++--- datafusion/core/src/test/mod.rs | 28 +- 3 files changed, 697 insertions(+), 137 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 2b650a42696b..2ecc1e11b985 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -60,8 +60,8 @@ use crate::physical_plan::{ use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; - use datafusion_physical_plan::repartition::RepartitionExec; + use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -769,7 +769,7 @@ mod tests { use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::{csv_exec_sorted, stream_exec_ordered}; + use crate::test::{csv_exec_ordered, csv_exec_sorted, stream_exec_ordered}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -777,6 +777,8 @@ mod tests { use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::{col, Column, NotExpr}; + use rstest::rstest; + fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); @@ -2140,12 +2142,19 @@ mod tests { Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering_unbounded() -> Result<()> { + async fn test_with_lost_ordering_unbounded_bounded( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - // create an unbounded source - let source = stream_exec_ordered(&schema, sort_exprs); + // create either bounded or unbounded source + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_ordered(&schema, sort_exprs) + }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2154,50 +2163,71 @@ mod tests { let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = vec![ "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; - let expected_optimized = [ + let expected_input_bounded = vec![ + "SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = vec![ "SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) - } - #[tokio::test] - async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { - let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - // create an unbounded source - let source = stream_exec_ordered(&schema, sort_exprs); - let repartition_rr = repartition_exec(source); - let repartition_hash = Arc::new(RepartitionExec::try_new( - repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), - )?) as _; - let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - - let expected_input = ["SortExec: expr=[a@0 ASC]", + // Expected bounded results with and without flag + let expected_optimized_bounded = vec![ + "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", ]; - let expected_optimized = [ + let expected_optimized_bounded_parallelize_sort = vec![ "SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = + if source_unbounded { + ( + expected_input_unbounded, + expected_optimized_unbounded.clone(), + expected_optimized_unbounded, + ) + } else { + ( + expected_input_bounded, + expected_optimized_bounded, + expected_optimized_bounded_parallelize_sort, + ) + }; + assert_optimized!( + expected_input, + expected_optimized, + physical_plan.clone(), + false + ); + assert_optimized!( + expected_input, + expected_optimized_sort_parallelize, + physical_plan, + true + ); Ok(()) } diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 671891be433c..0ff7e9f48edc 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -276,6 +276,9 @@ pub(crate) fn replace_with_order_preserving_variants( mod tests { use super::*; + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::filter::FilterExec; @@ -285,35 +288,95 @@ mod tests { use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::SessionConfig; - use crate::test::TestStreamPartition; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::TreeNode; - use datafusion_common::Result; + use datafusion_common::{Result, Statistics}; + use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::streaming::StreamingTableExec; + use rstest::rstest; - /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan - /// against the original and expected plans. + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans for both bounded and + /// unbounded cases. /// - /// `$EXPECTED_PLAN_LINES`: input plan - /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan - /// `$PLAN`: the plan to optimized - /// `$ALLOW_BOUNDED`: whether to allow the plan to be optimized for bounded cases - macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { + /// # Parameters + /// + /// * `EXPECTED_UNBOUNDED_PLAN_LINES`: Expected input unbounded plan. + /// * `EXPECTED_BOUNDED_PLAN_LINES`: Expected input bounded plan. + /// * `EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan, which is + /// the same regardless of the value of the `prefer_existing_sort` flag. + /// * `EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false` for bounded cases. + /// * `EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan + /// when the flag `prefer_existing_sort` is `true` for bounded cases. + /// * `$PLAN`: The plan to optimize. + /// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. + macro_rules! assert_optimized_in_all_boundedness_situations { + ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr) => { + if $SOURCE_UNBOUNDED { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_UNBOUNDED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } else { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_BOUNDED_PLAN_LINES, + $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } + }; + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `EXPECTED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false`. + /// * `EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan when + /// the flag `prefer_existing_sort` is `true`. + /// * `$PLAN`: The plan to optimize. + macro_rules! assert_optimized_prefer_sort_on_off { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { assert_optimized!( $EXPECTED_PLAN_LINES, $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN, + $PLAN.clone(), false ); + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN, + true + ); }; - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $ALLOW_BOUNDED: expr) => { + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `$EXPECTED_OPTIMIZED_PLAN_LINES`: Expected optimized plan. + /// * `$PLAN`: The plan to optimize. + /// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. + macro_rules! assert_optimized { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr) => { let physical_plan = $PLAN; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -329,8 +392,7 @@ mod tests { let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); // Run the rule top-down - // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; - let config = SessionConfig::new().with_prefer_existing_sort($ALLOW_BOUNDED); + let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; let optimized_physical_plan = parallel.plan; @@ -348,35 +410,67 @@ mod tests { #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected async fn test_replace_multiple_input_repartition_1( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -384,11 +478,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_inter_children_change_only( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -408,7 +506,8 @@ mod tests { sort2, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " FilterExec: c@1 > 3", @@ -420,8 +519,21 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", @@ -431,11 +543,38 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -443,11 +582,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_2( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); @@ -456,7 +599,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -464,18 +608,48 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -483,11 +657,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -497,7 +675,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", @@ -506,7 +685,18 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -514,11 +704,33 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -526,11 +738,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps_2( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); @@ -542,7 +758,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", @@ -552,7 +769,19 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -561,11 +790,35 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -573,11 +826,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_not_replacing_when_no_need_to_preserve_sorting( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -586,7 +843,8 @@ mod tests { let physical_plan: Arc = coalesce_partitions_exec(coalesce_batches_exec); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -594,7 +852,17 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -602,11 +870,26 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results same with and without flag, because there is no executor with ordering requirement + let expected_optimized_bounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -614,11 +897,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_multiple_replacable_repartitions( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -629,7 +916,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -639,7 +927,19 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", @@ -648,11 +948,35 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -660,11 +984,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_not_replace_with_different_orderings( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( @@ -678,25 +1006,49 @@ mod tests { sort, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results same with and without flag, because ordering requirement of the executor is different than the existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -704,35 +1056,67 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_lost_ordering( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortExec: expr=[a@0 ASC NULLS LAST]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -740,11 +1124,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_lost_and_kept_ordering( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -764,7 +1152,8 @@ mod tests { sort2, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " FilterExec: c@1 > 3", @@ -776,8 +1165,21 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", @@ -788,11 +1190,39 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -800,19 +1230,27 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_multiple_child_trees( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = stream_exec_ordered(&schema, left_sort_exprs); + let left_source = if source_unbounded { + stream_exec_ordered(&schema, left_sort_exprs) + } else { + csv_exec_sorted(&schema, left_sort_exprs) + }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = stream_exec_ordered(&schema, right_sort_exprs); + let right_source = if source_unbounded { + stream_exec_ordered(&schema, right_sort_exprs) + } else { + csv_exec_sorted(&schema, right_sort_exprs) + }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -831,7 +1269,8 @@ mod tests { sort, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", @@ -844,8 +1283,22 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", @@ -858,11 +1311,32 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. Hence no need to preserve + // existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -985,8 +1459,7 @@ mod tests { Ok(schema) } - // creates a csv exec source for the test purposes - // projection and has_header parameters are given static due to testing needs + // creates a stream exec source for the test purposes fn stream_exec_ordered( schema: &SchemaRef, sort_exprs: impl IntoIterator, @@ -1007,4 +1480,35 @@ mod tests { .unwrap(), ) } + + // creates a csv exec source for the test purposes + // projection and has_header parameters are given static due to testing needs + fn csv_exec_sorted( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + let projection: Vec = vec![0, 2, 3]; + + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new( + "file_path".to_string(), + 100, + )]], + statistics: Statistics::new_unknown(schema), + projection: Some(projection), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + }, + true, + 0, + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) + } } diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 7a63466a3906..ed5aa15e291b 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -43,13 +43,13 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, FileType, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; #[cfg(feature = "compression")] use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; #[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] @@ -334,6 +334,32 @@ pub fn stream_exec_ordered( ) } +/// Create a csv exec for tests +pub fn csv_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("file_path".to_string(), 100)]], + statistics: Statistics::new_unknown(schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + }, + true, + 0, + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) +} + /// A mock execution plan that simply returns the provided statistics #[derive(Debug, Clone)] pub struct StatisticsExec { From f041e73b48e426a3679301d3b28c9dc4410a8d97 Mon Sep 17 00:00:00 2001 From: Trevor Hilton Date: Tue, 19 Dec 2023 14:03:50 -0500 Subject: [PATCH 26/31] Port tests in `parquet.rs` to sqllogictest (#8560) * setup parquet.slt and port parquet_query test to it * port parquet_with_sort_order_specified, but missing files * port fixed_size_binary_columns test * port window_fn_timestamp_tz test * port parquet_single_nan_schema test * port parquet_query_with_max_min test * use COPY to create tables in parquet.slt to test partitioning over multi-file data * remove unneeded optimizer setting; check type of timestamp column --- datafusion/core/tests/sql/parquet.rs | 292 ----------------- .../sqllogictest/test_files/parquet.slt | 304 ++++++++++++++++++ 2 files changed, 304 insertions(+), 292 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/parquet.slt diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index 8f810a929df3..f80a28f7e4f9 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -15,207 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::{fs, path::Path}; - -use ::parquet::arrow::ArrowWriter; -use datafusion::{datasource::listing::ListingOptions, execution::options::ReadOptions}; use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; -use tempfile::TempDir; use super::*; -#[tokio::test] -async fn parquet_query() { - let ctx = SessionContext::new(); - register_alltypes_parquet(&ctx).await; - // NOTE that string_col is actually a binary column and does not have the UTF8 logical type - // so we need an explicit cast - let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+----+---------------------------+", - "| id | alltypes_plain.string_col |", - "+----+---------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+---------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -/// Test that if sort order is specified in ListingOptions, the sort -/// expressions make it all the way down to the ParquetExec -async fn parquet_with_sort_order_specified() { - let parquet_read_options = ParquetReadOptions::default(); - let session_config = SessionConfig::new().with_target_partitions(2); - - // The sort order is not specified - let options_no_sort = parquet_read_options.to_listing_options(&session_config); - - // The sort order is specified (not actually correct in this case) - let file_sort_order = [col("string_col"), col("int_col")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>(); - - let options_sort = parquet_read_options - .to_listing_options(&session_config) - .with_file_sort_order(vec![file_sort_order]); - - // This string appears in ParquetExec if the output ordering is - // specified - let expected_output_ordering = - "output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST]"; - - // when sort not specified, should not appear in the explain plan - let num_files = 1; - assert_not_contains!( - run_query_with_options(options_no_sort, num_files).await, - expected_output_ordering - ); - - // when sort IS specified, SHOULD appear in the explain plan - let num_files = 1; - assert_contains!( - run_query_with_options(options_sort.clone(), num_files).await, - expected_output_ordering - ); - - // when sort IS specified, but there are too many files (greater - // than the number of partitions) sort should not appear - let num_files = 3; - assert_not_contains!( - run_query_with_options(options_sort, num_files).await, - expected_output_ordering - ); -} - -/// Runs a limit query against a parquet file that was registered from -/// options on num_files copies of all_types_plain.parquet -async fn run_query_with_options(options: ListingOptions, num_files: usize) -> String { - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::parquet_test_data(); - let file_path = format!("{testdata}/alltypes_plain.parquet"); - - // Create a directory of parquet files with names - // 0.parquet - // 1.parquet - let tmpdir = TempDir::new().unwrap(); - for i in 0..num_files { - let target_file = tmpdir.path().join(format!("{i}.parquet")); - println!("Copying {file_path} to {target_file:?}"); - std::fs::copy(&file_path, target_file).unwrap(); - } - - let provided_schema = None; - let sql_definition = None; - ctx.register_listing_table( - "t", - tmpdir.path().to_string_lossy(), - options.clone(), - provided_schema, - sql_definition, - ) - .await - .unwrap(); - - let batches = ctx.sql("explain select int_col, string_col from t order by string_col, int_col limit 10") - .await - .expect("planing worked") - .collect() - .await - .expect("execution worked"); - - arrow::util::pretty::pretty_format_batches(&batches) - .unwrap() - .to_string() -} - -#[tokio::test] -async fn fixed_size_binary_columns() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/test_binary.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT ids FROM t0 ORDER BY ids"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(466, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[tokio::test] -async fn window_fn_timestamp_tz() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/timestamp_with_tz.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT count, LAG(timestamp, 1) OVER (ORDER BY timestamp) FROM t0"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - let mut num_rows = 0; - for batch in results { - num_rows += batch.num_rows(); - assert_eq!(2, batch.num_columns()); - - let ty = batch.column(0).data_type().clone(); - assert_eq!(DataType::Int64, ty); - - let ty = batch.column(1).data_type().clone(); - assert_eq!( - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), - ty - ); - } - - assert_eq!(131072, num_rows); -} - -#[tokio::test] -async fn parquet_single_nan_schema() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "single_nan", - &format!("{testdata}/single_nan.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT mycol FROM single_nan"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(1, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - #[tokio::test] #[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] async fn parquet_list_columns() { @@ -286,98 +89,3 @@ async fn parquet_list_columns() { assert_eq!(result.value(2), "hij"); assert_eq!(result.value(3), "xyz"); } - -#[tokio::test] -async fn parquet_query_with_max_min() { - let tmp_dir = TempDir::new().unwrap(); - let table_dir = tmp_dir.path().join("parquet_test"); - let table_path = Path::new(&table_dir); - - let fields = vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Int64, true), - Field::new("c4", DataType::Date32, true), - ]; - - let schema = Arc::new(Schema::new(fields.clone())); - - if let Ok(()) = fs::create_dir(table_path) { - let filename = "foo.parquet"; - let path = table_path.join(filename); - let file = fs::File::create(path).unwrap(); - let mut writer = - ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), None) - .unwrap(); - - // create mock record batch - let c1s = Arc::new(Int32Array::from(vec![1, 2, 3])); - let c2s = Arc::new(StringArray::from(vec!["aaa", "bbb", "ccc"])); - let c3s = Arc::new(Int64Array::from(vec![100, 200, 300])); - let c4s = Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])); - let rec_batch = - RecordBatch::try_new(schema.clone(), vec![c1s, c2s, c3s, c4s]).unwrap(); - - writer.write(&rec_batch).unwrap(); - writer.close().unwrap(); - } - - // query parquet - let ctx = SessionContext::new(); - - ctx.register_parquet( - "foo", - &format!("{}/foo.parquet", table_dir.to_str().unwrap()), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT max(c1) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c1) |", - "+-------------+", - "| 3 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c2) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c2) |", - "+-------------+", - "| aaa |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT max(c3) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c3) |", - "+-------------+", - "| 300 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c4) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c4) |", - "+-------------+", - "| 1970-01-02 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); -} diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt new file mode 100644 index 000000000000..bbe7f33e260c --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -0,0 +1,304 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# TESTS FOR PARQUET FILES + +# Set 2 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 2; + +# Create a table as a data source +statement ok +CREATE TABLE src_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) AS VALUES +(1, 'aaa', 100, 1), +(2, 'bbb', 200, 2), +(3, 'ccc', 300, 3), +(4, 'ddd', 400, 4), +(5, 'eee', 500, 5), +(6, 'fff', 600, 6), +(7, 'ggg', 700, 7), +(8, 'hhh', 800, 8), +(9, 'iii', 900, 9); + +# Setup 2 files, i.e., as many as there are partitions: + +# File 1: +query ITID +COPY (SELECT * FROM src_table LIMIT 3) +TO 'test_files/scratch/parquet/test_table/0.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# File 2: +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# Create a table from generated parquet files, without ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table'; + +# Basic query: +query ITID +SELECT * FROM test_table ORDER BY int_col; +---- +1 aaa 100 1970-01-02 +2 bbb 200 1970-01-03 +3 ccc 300 1970-01-04 +4 ddd 400 1970-01-05 +5 eee 500 1970-01-06 +6 fff 600 1970-01-07 + +# Check output plan, expect no "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col] + +# Tear down test_table: +statement ok +DROP TABLE test_table; + +# Create test_table again, but with ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +WITH ORDER (string_col ASC NULLS LAST, int_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet/test_table'; + +# Check output plan, expect an "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col], output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] + +# Add another file to the directory underlying test_table +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# Check output plan again, expect no "output_ordering" clause in the physical_plan -> ParquetExec, +# due to there being more files than partitions: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/2.parquet]]}, projection=[int_col, string_col] + + +# Perform queries using MIN and MAX +query I +SELECT max(int_col) FROM test_table; +---- +9 + +query T +SELECT min(string_col) FROM test_table; +---- +aaa + +query I +SELECT max(bigint_col) FROM test_table; +---- +900 + +query D +SELECT min(date_col) FROM test_table; +---- +1970-01-02 + +# Clean up +statement ok +DROP TABLE test_table; + +# Setup alltypes_plain table: +statement ok +CREATE EXTERNAL TABLE alltypes_plain ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/alltypes_plain.parquet' + +# Test a basic query with a CAST: +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# Clean up +statement ok +DROP TABLE alltypes_plain; + +# Perform SELECT on table with fixed sized binary columns + +statement ok +CREATE EXTERNAL TABLE test_binary +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/test_binary.parquet'; + +# Check size of table: +query I +SELECT count(ids) FROM test_binary; +---- +466 + +# Do the SELECT query: +query ? +SELECT ids FROM test_binary ORDER BY ids LIMIT 10; +---- +008c7196f68089ab692e4739c5fd16b5 +00a51a7bc5ff8eb1627f8f3dc959dce8 +0166ce1d46129ad104fa4990c6057c91 +03a4893f3285b422820b4cd74c9b9786 +04999ac861e14682cd339eae2cc74359 +04b86bf8f228739fde391f850636a77d +050fb9cf722a709eb94b70b3ee7dc342 +052578a65e8e91b8526b182d40e846e8 +05408e6a403e4296526006e20cc4a45a +0592e6fb7d7169b888a4029b53abb701 + +# Clean up +statement ok +DROP TABLE test_binary; + +# Perform a query with a window function and timestamp data: + +statement ok +CREATE EXTERNAL TABLE timestamp_with_tz +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/timestamp_with_tz.parquet'; + +# Check size of table: +query I +SELECT COUNT(*) FROM timestamp_with_tz; +---- +131072 + +# Perform the query: +query IPT +SELECT + count, + LAG(timestamp, 1) OVER (ORDER BY timestamp), + arrow_typeof(LAG(timestamp, 1) OVER (ORDER BY timestamp)) +FROM timestamp_with_tz +LIMIT 10; +---- +0 NULL Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +4 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +14 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) + +# Clean up +statement ok +DROP TABLE timestamp_with_tz; + +# Test a query from the single_nan data set: +statement ok +CREATE EXTERNAL TABLE single_nan +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/single_nan.parquet'; + +# Check table size: +query I +SELECT COUNT(*) FROM single_nan; +---- +1 + +# Query for the single NULL: +query R +SELECT mycol FROM single_nan; +---- +NULL + +# Clean up +statement ok +DROP TABLE single_nan; From b456cf78db87bd1369b79a7eec4e3764f551982d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 19 Dec 2023 14:56:34 -0500 Subject: [PATCH 27/31] Minor: avoid a copy in Expr::unalias (#8588) --- datafusion/expr/src/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f0aab95b8f0d..b46e9ec8f69d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -956,7 +956,7 @@ impl Expr { /// Remove an alias from an expression if one exists. pub fn unalias(self) -> Expr { match self { - Expr::Alias(alias) => alias.expr.as_ref().clone(), + Expr::Alias(alias) => *alias.expr, _ => self, } } From 1bcaac4835457627d881f755a87dbd140ec3388c Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Wed, 20 Dec 2023 10:11:29 +0800 Subject: [PATCH 28/31] Minor: support complex expr as the arg in the ApproxPercentileCont function (#8580) * support complex lit expr for the arg * enchancement the percentile --- .../tests/dataframe/dataframe_functions.rs | 20 +++++++++ .../src/aggregate/approx_percentile_cont.rs | 45 +++++++++---------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 9677003ec226..fe56fc22ea8c 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -31,6 +31,7 @@ use datafusion::prelude::*; use datafusion::execution::context::SessionContext; use datafusion::assert_batches_eq; +use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast}; async fn create_test_table() -> Result { @@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); + // the arg2 parameter is a complex expr, but it can be evaluated to the literal value + let alias_expr = Expr::Alias(Alias::new( + cast(lit(0.5), DataType::Float32), + None::<&str>, + "arg_2".to_string(), + )); + let expr = approx_percentile_cont(col("b"), alias_expr); + let df = create_test_table().await?; + let expected = [ + "+--------------------------------------+", + "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "+--------------------------------------+", + "| 10 |", + "+--------------------------------------+", + ]; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index aa4749f64ae9..15c0fb3ace4d 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -18,7 +18,7 @@ use crate::aggregate::tdigest::TryIntoF64; use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE}; use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{format_state_name, Literal}; +use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::{ array::{ @@ -27,11 +27,13 @@ use arrow::{ }, datatypes::{DataType, Field}, }; +use arrow_array::RecordBatch; +use arrow_schema::Schema; use datafusion_common::{ downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, ColumnarValue}; use std::{any::Any, iter, sync::Arc}; /// APPROX_PERCENTILE_CONT aggregate expression @@ -131,18 +133,22 @@ impl PartialEq for ApproxPercentileCont { } } +fn get_lit_value(expr: &Arc) -> Result { + let empty_schema = Schema::empty(); + let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); + let result = expr.evaluate(&empty_batch)?; + match result { + ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( + "The expr {:?} can't be evaluated to scalar value", + expr + ))), + ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + } +} + fn validate_input_percentile_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let percentile = match lit { + let lit = get_lit_value(expr)?; + let percentile = match &lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q, got => return not_impl_err!( @@ -161,17 +167,8 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { } fn validate_input_max_size_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let max_size = match lit { + let lit = get_lit_value(expr)?; + let max_size = match &lit { ScalarValue::UInt8(Some(q)) => *q as usize, ScalarValue::UInt16(Some(q)) => *q as usize, ScalarValue::UInt32(Some(q)) => *q as usize, From 6f5230ffc77ec0151a7aa870808d2fb31e6146c7 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 20 Dec 2023 10:58:49 +0300 Subject: [PATCH 29/31] Bugfix: Add functional dependency check and aggregate try_new schema (#8584) * Add functional dependency check and aggregate try_new schema * Update comments, make implementation idiomatic * Use constraint during stream table initialization --- datafusion/common/src/dfschema.rs | 16 ++++ datafusion/core/src/datasource/stream.rs | 3 +- datafusion/expr/src/utils.rs | 13 +-- .../physical-plan/src/aggregates/mod.rs | 92 ++++++++++++++++++- .../sqllogictest/test_files/groupby.slt | 12 +++ 5 files changed, 125 insertions(+), 11 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index e06f947ad5e7..d6e4490cec4c 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -347,6 +347,22 @@ impl DFSchema { .collect() } + /// Find all fields indices having the given qualifier + pub fn fields_indices_with_qualified( + &self, + qualifier: &TableReference, + ) -> Vec { + self.fields + .iter() + .enumerate() + .filter_map(|(idx, field)| { + field + .qualifier() + .and_then(|q| q.eq(qualifier).then_some(idx)) + }) + .collect() + } + /// Find all fields match the given name pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { self.fields diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index b9b45a6c7470..830cd7a07e46 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -64,7 +64,8 @@ impl TableProviderFactory for StreamTableFactory { .with_encoding(encoding) .with_order(cmd.order_exprs.clone()) .with_header(cmd.has_header) - .with_batch_size(state.config().batch_size()); + .with_batch_size(state.config().batch_size()) + .with_constraints(cmd.constraints.clone()); Ok(Arc::new(StreamTable(Arc::new(config)))) } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index abdd7f5f57f6..09f4842c9e64 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -32,6 +32,7 @@ use crate::{ use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, @@ -425,18 +426,18 @@ pub fn expand_qualified_wildcard( wildcard_options: Option<&WildcardAdditionalOptions>, ) -> Result> { let qualifier = TableReference::from(qualifier); - let qualified_fields: Vec = schema - .fields_with_qualified(&qualifier) - .into_iter() - .cloned() - .collect(); + let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + let projected_func_dependencies = schema + .functional_dependencies() + .project_functional_dependencies(&qualified_indices, qualified_indices.len()); + let qualified_fields = get_at_indices(schema.fields(), &qualified_indices)?; if qualified_fields.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); } let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? // We can use the functional dependencies as is, since it only stores indices: - .with_functional_dependencies(schema.functional_dependencies().clone())?; + .with_functional_dependencies(projected_func_dependencies)?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index c74c4ac0f821..921de96252f0 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -43,7 +43,7 @@ use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ aggregate::is_order_sensitive, - equivalence::collapse_lex_req, + equivalence::{collapse_lex_req, ProjectionMapping}, expressions::{Column, Max, Min, UnKnownColumn}, physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, @@ -59,7 +59,6 @@ mod topk; mod topk_stream; pub use datafusion_expr::AggregateFunction; -use datafusion_physical_expr::equivalence::ProjectionMapping; pub use datafusion_physical_expr::expressions::create_aggregate_expr; /// Hash aggregate modes @@ -464,7 +463,7 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec>, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -482,6 +481,37 @@ impl AggregateExec { group_by.expr.len(), )); let original_schema = Arc::new(original_schema); + AggregateExec::try_new_with_schema( + mode, + group_by, + aggr_expr, + filter_expr, + input, + input_schema, + schema, + original_schema, + ) + } + + /// Create a new hash aggregate execution plan with the given schema. + /// This constructor isn't part of the public API, it is used internally + /// by Datafusion to enforce schema consistency during when re-creating + /// `AggregateExec`s inside optimization rules. Schema field names of an + /// `AggregateExec` depends on the names of aggregate expressions. Since + /// a rule may re-write aggregate expressions (e.g. reverse them) during + /// initialization, field names may change inadvertently if one re-creates + /// the schema in such cases. + #[allow(clippy::too_many_arguments)] + fn try_new_with_schema( + mode: AggregateMode, + group_by: PhysicalGroupBy, + mut aggr_expr: Vec>, + filter_expr: Vec>>, + input: Arc, + input_schema: SchemaRef, + schema: SchemaRef, + original_schema: SchemaRef, + ) -> Result { // Reset ordering requirement to `None` if aggregator is not order-sensitive let mut order_by_expr = aggr_expr .iter() @@ -858,13 +888,15 @@ impl ExecutionPlan for AggregateExec { self: Arc, children: Vec>, ) -> Result> { - let mut me = AggregateExec::try_new( + let mut me = AggregateExec::try_new_with_schema( self.mode, self.group_by.clone(), self.aggr_expr.clone(), self.filter_expr.clone(), children[0].clone(), self.input_schema.clone(), + self.schema.clone(), + self.original_schema.clone(), )?; me.limit = self.limit; Ok(Arc::new(me)) @@ -2162,4 +2194,56 @@ mod tests { assert_eq!(res, common_requirement); Ok(()) } + + #[test] + fn test_agg_exec_same_schema() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let sort_expr = vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_desc, + }]; + let sort_expr_reverse = reverse_order_bys(&sort_expr); + let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); + + let aggregates: Vec> = vec![ + Arc::new(FirstValue::new( + col_b.clone(), + "FIRST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr_reverse.clone(), + vec![DataType::Float64], + )), + Arc::new(LastValue::new( + col_b.clone(), + "LAST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr.clone(), + vec![DataType::Float64], + )), + ]; + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None, None], + blocking_exec.clone(), + schema, + )?); + let new_agg = aggregate_exec + .clone() + .with_new_children(vec![blocking_exec])?; + assert_eq!(new_agg.schema(), aggregate_exec.schema()); + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 44d30ba0b34c..f1b6a57287b5 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -4280,3 +4280,15 @@ LIMIT 5 2 0 0 3 0 0 4 0 1 + + +query ITIPTR rowsort +SELECT r.* +FROM sales_global_with_pk as l, sales_global_with_pk as r +LIMIT 5 +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 TUR 4 2022-01-03T10:00:00 TRY 100 From 8d72196f957147335b3828f44153277126eb3c0f Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 20 Dec 2023 17:03:57 +0300 Subject: [PATCH 30/31] Remove GroupByOrderMode (#8593) --- .../physical-plan/src/aggregates/mod.rs | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 921de96252f0..f779322456ca 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -101,34 +101,6 @@ impl AggregateMode { } } -/// Group By expression modes -/// -/// `PartiallyOrdered` and `FullyOrdered` are used to reason about -/// when certain group by keys will never again be seen (and thus can -/// be emitted by the grouping operator). -/// -/// Specifically, each distinct combination of the relevant columns -/// are contiguous in the input, and once a new combination is seen -/// previous combinations are guaranteed never to appear again -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum GroupByOrderMode { - /// The input is known to be ordered by a preset (prefix but - /// possibly reordered) of the expressions in the `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c` and we group - /// by `b, a, d`, `PartiallyOrdered` means a subset of group `b, - /// a, d` defines a preset for the existing ordering, in this case - /// `a, b`. - PartiallyOrdered, - /// The input is known to be ordered by *all* the expressions in the - /// `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c, d` and we group by b, a, - /// `Ordered` means that all of the of group by expressions appear - /// as a preset for the existing ordering, in this case `a, b`. - FullyOrdered, -} - /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. From b925b78fd8040f858168e439eda5042bd2a34af6 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 20 Dec 2023 18:18:56 +0100 Subject: [PATCH 31/31] replace not-impl-err (#8589) --- datafusion/physical-expr/src/array_expressions.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index d39658108337..0a7631918804 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -455,7 +455,7 @@ pub fn array_element(args: &[ArrayRef]) -> Result { let indexes = as_int64_array(&args[1])?; general_array_element::(array, indexes) } - _ => not_impl_err!( + _ => exec_err!( "array_element does not support type: {:?}", args[0].data_type() ), @@ -571,7 +571,7 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { let to_array = as_int64_array(&args[2])?; general_array_slice::(array, from_array, to_array) } - _ => not_impl_err!("array_slice does not support type: {:?}", array_data_type), + _ => exec_err!("array_slice does not support type: {:?}", array_data_type), } } @@ -1335,7 +1335,7 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { general_positions::(arr, element) } array_type => { - not_impl_err!("array_positions does not support type '{array_type:?}'.") + exec_err!("array_positions does not support type '{array_type:?}'.") } } }