diff --git a/src/logical/builder.rs b/src/logical/builder.rs index 3c39354..e37561f 100644 --- a/src/logical/builder.rs +++ b/src/logical/builder.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{ expr::{AggregateExpr, LogicalExpr, SortExpr}, - plan::{Aggregate, CrossJoin, EmptyRelation, Join, LogicalPlan, Projection, Sort, TableScan}, + plan::{Aggregate, CrossJoin, EmptyRelation, Join, Limit, LogicalPlan, Projection, Sort, TableScan}, }; use crate::{common::JoinType, error::Result}; use crate::{common::OwnedTableRelation, datasource::DataSource}; @@ -98,7 +98,13 @@ impl LogicalPlanBuilder { }) } - pub(crate) fn limit(&self) -> Self { - todo!() + pub fn limit(self, limit: i64, offset: i64) -> Self { + LogicalPlanBuilder { + plan: LogicalPlan::Limit(Limit { + input: Box::new(self.plan), + fetch: limit as usize, + offset: offset as usize, + }), + } } } diff --git a/src/logical/expr/mod.rs b/src/logical/expr/mod.rs index de44ab8..e25a39e 100644 --- a/src/logical/expr/mod.rs +++ b/src/logical/expr/mod.rs @@ -89,3 +89,10 @@ impl Display for LogicalExpr { } } } + +pub(crate) fn get_expr_value(expr: LogicalExpr) -> Result { + match expr { + LogicalExpr::Literal(ScalarValue::Int64(Some(v))) => Ok(v), + _ => Err(Error::InternalError(format!("Unexpected expression in"))), + } +} diff --git a/src/logical/plan/limit.rs b/src/logical/plan/limit.rs new file mode 100644 index 0000000..b7cfa28 --- /dev/null +++ b/src/logical/plan/limit.rs @@ -0,0 +1,36 @@ +use std::fmt::Display; + +use arrow::datatypes::SchemaRef; + +use crate::logical::plan::LogicalPlan; + +#[derive(Debug, Clone)] +pub struct Limit { + pub input: Box, + pub fetch: usize, + pub offset: usize, +} + +impl Display for Limit { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Limit: fetch={}, offset={}", self.fetch, self.offset) + } +} + +impl Limit { + pub fn new(input: LogicalPlan, fetch: usize, offset: usize) -> Self { + Self { + input: Box::new(input), + fetch, + offset, + } + } + + pub fn schema(&self) -> SchemaRef { + self.input.schema() + } + + pub fn children(&self) -> Option> { + self.input.children() + } +} diff --git a/src/logical/plan/mod.rs b/src/logical/plan/mod.rs index e09f5e4..c2ab1ba 100644 --- a/src/logical/plan/mod.rs +++ b/src/logical/plan/mod.rs @@ -1,6 +1,7 @@ mod aggregate; mod filter; mod join; +mod limit; mod projection; mod scan; mod sort; @@ -9,6 +10,7 @@ mod sub_query; pub use aggregate::Aggregate; pub use filter::Filter; pub use join::*; +pub use limit::Limit; pub use projection::Projection; pub use scan::TableScan; pub use sort::*; @@ -49,6 +51,8 @@ pub enum LogicalPlan { SubqueryAlias(SubqueryAlias), /// Sort the result set by the specified expressions. Sort(Sort), + /// Limit the number of rows in the result set, and optionally an offset. + Limit(Limit), } impl LogicalPlan { @@ -63,6 +67,7 @@ impl LogicalPlan { LogicalPlan::SubqueryAlias(s) => s.schema(), LogicalPlan::Join(j) => j.schema(), LogicalPlan::Sort(s) => s.schema(), + LogicalPlan::Limit(l) => l.schema(), } } @@ -77,6 +82,7 @@ impl LogicalPlan { LogicalPlan::SubqueryAlias(s) => s.children(), LogicalPlan::Join(j) => j.children(), LogicalPlan::Sort(s) => s.children(), + LogicalPlan::Limit(l) => l.children(), } } } @@ -93,6 +99,7 @@ impl std::fmt::Display for LogicalPlan { LogicalPlan::SubqueryAlias(s) => write!(f, "{}", s), LogicalPlan::Join(j) => write!(f, "{}", j), LogicalPlan::Sort(s) => write!(f, "{}", s), + LogicalPlan::Limit(l) => write!(f, "{}", l), } } } diff --git a/src/physical/plan/limit.rs b/src/physical/plan/limit.rs new file mode 100644 index 0000000..ed75e44 --- /dev/null +++ b/src/physical/plan/limit.rs @@ -0,0 +1,75 @@ +use arrow::array::RecordBatch; +use arrow::datatypes::SchemaRef; + +use crate::error::Result; +use crate::physical::plan::PhysicalPlan; + +use std::sync::Arc; + +pub struct Limit { + pub input: Arc, + pub fetch: usize, + pub offset: usize, +} + +impl Limit { + pub fn new(input: Arc, fetch: usize, offset: Option) -> Self { + Self { + input, + fetch, + offset: offset.unwrap_or_default(), + } + } +} + +impl PhysicalPlan for Limit { + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn execute(&self) -> Result> { + Ok(self + .input + .execute()? + .into_iter() + .map(|batch| batch.slice(self.offset, self.fetch)) + .collect()) + } + + fn children(&self) -> Option>> { + self.input.children() + } +} + +#[cfg(test)] +mod test { + use crate::{build_table_scan, physical::plan::PhysicalPlan, test_utils::assert_batch_eq}; + + use super::Limit; + + #[test] + fn test_limit() { + let input = build_table_scan!( + ("a", Int32Type, DataType::Int32, vec![1, 2, 3, 4]), + ("b", Float64Type, DataType::Float64, vec![1.0, 2.0, 3.0, 4.0]), + ("c", UInt64Type, DataType::UInt64, vec![1, 2, 3, 4]), + ); + + let limit = Limit::new(input, 3, Some(1)); + + let reuslts = limit.execute().unwrap(); + + assert_batch_eq( + &reuslts, + vec![ + "+---+-----+---+", + "| a | b | c |", + "+---+-----+---+", + "| 2 | 2.0 | 2 |", + "| 3 | 3.0 | 3 |", + "| 4 | 4.0 | 4 |", + "+---+-----+---+", + ], + ) + } +} diff --git a/src/physical/plan/mod.rs b/src/physical/plan/mod.rs index c190a01..5577114 100644 --- a/src/physical/plan/mod.rs +++ b/src/physical/plan/mod.rs @@ -4,12 +4,14 @@ mod join; mod projection; mod scan; mod sort; +mod limit; pub use aggregate::HashAggregate; pub use filter::Filter; pub use join::{join_schema, ColumnIndex, CrossJoin, Join, JoinFilter, JoinSide}; pub use projection::Projection; pub use scan::Scan; +pub use limit::Limit; use std::sync::Arc; diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 690e473..7cbab07 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -160,6 +160,7 @@ impl QueryPlanner for DefaultQueryPlanner { LogicalPlan::SubqueryAlias(_) => todo!(), LogicalPlan::Join(join) => self.physical_plan_join(join), LogicalPlan::Sort(_) => todo!(), + LogicalPlan::Limit(_) => todo!(), } } diff --git a/src/planner/sql.rs b/src/planner/sql.rs index a88c493..c1814ef 100644 --- a/src/planner/sql.rs +++ b/src/planner/sql.rs @@ -51,10 +51,6 @@ impl<'a> SqlQueryPlanner<'a> { } } - fn create_logical_expr(&self) -> Result { - todo!() - } - fn select_to_plan(&mut self, select: Select, mut context: &mut PlannerContext) -> Result { // process `with` clause if let Some(with) = select.with { @@ -95,10 +91,10 @@ impl<'a> SqlQueryPlanner<'a> { // process the LIMIT clause if let (Some(limit), Some(offset)) = (select.limit, select.offset) { - let limit = self.sql_to_expr(context, limit)?; - let offset = self.sql_to_expr(context, offset)?; - // Ok(LogicalPlanBuilder::from(plan).limit(limit).offset(offset)?) - Ok(plan) + let limit = self.sql_to_expr(context, limit).and_then(get_expr_value)?; + let offset = self.sql_to_expr(context, offset).and_then(get_expr_value)?; + + Ok(LogicalPlanBuilder::from(plan).limit(limit, offset).build()) } else { Ok(plan) } @@ -819,13 +815,9 @@ mod tests { #[test] fn test_limit() { let sql = "select id from person where person.id > 100 LIMIT 5 OFFSET 0;"; - let expected = "Limit: skip=0, fetch=5\ - \n Projection: person.id\ - \n Filter: person.id > Int64(100)\ - \n TableScan: person"; + let expected = "Limit: fetch=5, offset=0\n Filter: person.id > Int64(100)\n TableScan: person\n"; quick_test(sql, expected); - // Flip the order of LIMIT and OFFSET in the query. Plan should remain the same. let sql = "SELECT id FROM person WHERE person.id > 100 OFFSET 0 LIMIT 5;"; quick_test(sql, expected); } diff --git a/src/test_utils.rs b/src/test_utils.rs index d755d4d..f6567a1 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -30,11 +30,12 @@ macro_rules! build_schema { #[macro_export] macro_rules! build_table_scan { ( $(($column: expr, $data_type: ty, $f_dy: expr, $data: expr)),+$(,)? ) => { - { + { use crate::datasource::memory::MemoryDataSource; use crate::physical::plan::Scan; - use arrow::array::{Array, PrimitiveArray}; + use arrow::array::{Array,RecordBatch, PrimitiveArray}; use arrow::datatypes::*; + use std::sync::Arc; let schema = Schema::new(vec![ $(