Skip to content

Commit

Permalink
feat: Add support for bigint and table schema
Browse files Browse the repository at this point in the history
This commit adds support for bigint data type and introduces the table_schema module. The table_schema module defines the TableSchema struct, which contains information about the schema and qualified name of a table.

fix: Fix order by alias and cte table issues

This commit fixes issues related to ordering by alias and common table expressions (CTEs).

refactor: Refactor merge_schema function

The merge_schema function in the utils module has been refactored to improve readability and maintainability.

test: Add count wildcard rule optimization

This commit adds a count wildcard rule optimization to the optimizer. The optimization replaces count(*) expressions with count(1) to improve query performance.

chore: Update parser to support column aliases

The parser has been updated to support column aliases in SELECT statements.

docs: Update documentation for binary expressions

The documentation for the BinaryExpr struct has been updated to provide more clarity on its usage and behavior.
  • Loading branch information
holicc committed Oct 9, 2024
1 parent 1eebc8f commit 4e445fb
Show file tree
Hide file tree
Showing 23 changed files with 1,053 additions and 728 deletions.
2 changes: 2 additions & 0 deletions qurious/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
pub mod join_type;
pub mod table_relation;
pub mod table_schema;
pub mod transformed;
11 changes: 11 additions & 0 deletions qurious/src/common/table_schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use std::sync::Arc;

use super::table_relation::TableRelation;
use arrow::datatypes::SchemaRef;

pub type TableSchemaRef = Arc<TableSchema>;

pub struct TableSchema {
pub schema: SchemaRef,
pub qualified_name: Vec<Option<TableRelation>>,
}
109 changes: 109 additions & 0 deletions qurious/src/common/transformed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use crate::error::Result;

pub struct Transformed<T> {
pub data: T,
pub transformed: bool,
}

impl<T> Transformed<T> {
pub fn yes(data: T) -> Self {
Transformed {
data,
transformed: true,
}
}

pub fn no(data: T) -> Self {
Transformed {
data,
transformed: false,
}
}

pub fn update<U, F>(self, f: F) -> Transformed<U>
where
F: FnOnce(T) -> U,
{
Transformed {
data: f(self.data),
transformed: self.transformed,
}
}

pub fn transform_children<F>(self, f: F) -> Result<Transformed<T>>
where
F: FnOnce(T) -> Result<Transformed<T>>,
{
f(self.data).map(|mut t| {
t.transformed |= self.transformed;
t
})
}
}

pub trait TransformedResult<T> {
fn data(self) -> Result<T>;
}

impl<T> TransformedResult<T> for Result<Transformed<T>> {
fn data(self) -> Result<T> {
self.map(|t| t.data)
}
}

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum TreeNodeRecursion {
Continue,
Stop,
}

impl TreeNodeRecursion {
pub fn visit_children<F>(self, f: F) -> Result<TreeNodeRecursion>
where
F: FnOnce() -> Result<TreeNodeRecursion>,
{
match self {
TreeNodeRecursion::Continue => f(),
TreeNodeRecursion::Stop => Ok(self),
}
}
}

pub trait TransformNode: Sized + Clone {
fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
where
F: FnMut(Self) -> Result<Transformed<Self>>;

fn apply_children<'n, F>(&'n self, f: F) -> Result<TreeNodeRecursion>
where
F: FnMut(&'n Self) -> Result<TreeNodeRecursion>;

fn transform<F>(self, mut f: F) -> Result<Transformed<Self>>
where
F: FnMut(Self) -> Result<Transformed<Self>>,
{
transform_impl(self, &mut f)
}

fn apply<'n, F>(&'n self, mut f: F) -> Result<TreeNodeRecursion>
where
F: FnMut(&'n Self) -> Result<TreeNodeRecursion>,
{
apply_impl(self, &mut f)
}
}

fn transform_impl<N, F>(node: N, f: &mut F) -> Result<Transformed<N>>
where
N: TransformNode,
F: FnMut(N) -> Result<Transformed<N>>,
{
f(node.clone())?.transform_children(|n| n.map_children(|c| transform_impl(c, f)))
}

fn apply_impl<'n, N: TransformNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
node: &'n N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
}
4 changes: 2 additions & 2 deletions qurious/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use crate::{
error::Result,
logical::{
expr::{self, LogicalExpr},
expr::LogicalExpr,
plan::{Aggregate, Filter, LogicalPlan, Projection},
},
planner::QueryPlanner,
Expand Down Expand Up @@ -50,7 +50,7 @@ impl DataFrame {
})
}

pub fn aggregate(self, group_by: Vec<LogicalExpr>, aggr_expr: Vec<expr::AggregateExpr>) -> Result<Self> {
pub fn aggregate(self, group_by: Vec<LogicalExpr>, aggr_expr: Vec<LogicalExpr>) -> Result<Self> {
Aggregate::try_new(self.plan.clone(), group_by, aggr_expr).map(|a| Self {
planner: self.planner,
plan: LogicalPlan::Aggregate(a),
Expand Down
23 changes: 8 additions & 15 deletions qurious/src/execution/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::internal_err;
use crate::logical::plan::{
CreateMemoryTable, DdlStatement, DmlOperator, DmlStatement, DropTable, Filter, LogicalPlan,
};
use crate::optimizer::Optimzier;
use crate::optimizer::Optimizer;
use crate::planner::sql::{parse_csv_options, parse_file_path, SqlQueryPlanner};
use crate::planner::QueryPlanner;
use crate::provider::catalog::CatalogProvider;
Expand All @@ -32,7 +32,7 @@ pub struct ExecuteSession {
planner: Arc<dyn QueryPlanner>,
table_factory: DefaultTableFactory,
catalog_list: CatalogProviderList,
optimizer: Optimzier,
optimizer: Optimizer,
udfs: RwLock<HashMap<String, Arc<dyn UserDefinedFunction>>>,
}

Expand All @@ -59,7 +59,7 @@ impl ExecuteSession {
planner: Arc::new(DefaultQueryPlanner::default()),
catalog_list,
table_factory: DefaultTableFactory::new(),
optimizer: Optimzier::new(),
optimizer: Optimizer::new(),
udfs,
})
}
Expand Down Expand Up @@ -252,25 +252,18 @@ mod tests {
#[test]
fn test_create_table() -> Result<()> {
let session = ExecuteSession::new()?;
// session.sql("create table a(v1 int, v2 int);")?;
session.sql("create table a(v1 int, v2 int);")?;
// session.sql("create table b(v3 int, v4 int);")?;
// session.sql("create table t(v1 int not null, v2 int not null, v3 double not null)")?;
session.sql("create table t(v1 int not null, v2 int not null, v3 double not null)")?;

// session.sql("create table x(a int, b int);")?;
// session.sql("create table y(c int, d int);")?;

// println!("++++++++++++++");

// session.sql("insert into a values (1, 1), (2, 2), (3, 3);")?;
session.sql("insert into a values (1, 1), (2, 2), (3, 3);")?;
// session.sql("insert into b values (1, 100), (3, 300), (4, 400);")?;
// session.sql("select a, b, c, d from x join y on a = c")?;


let batch = session.sql("
WITH
cte AS (SELECT 42 AS i),
cte2 AS (SELECT i*100 AS x FROM cte)
SELECT * FROM cte2;")?;
println!("++++++++++++++");
let batch = session.sql("with cte as (select v1 as i from a) select i*100 from cte")?;

print_batches(&batch)?;

Expand Down
4 changes: 2 additions & 2 deletions qurious/src/logical/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use arrow::datatypes::Schema;
use std::sync::Arc;

use super::{
expr::{AggregateExpr, LogicalExpr, SortExpr},
expr::{LogicalExpr, SortExpr},
plan::{Aggregate, CrossJoin, EmptyRelation, Join, Limit, LogicalPlan, Projection, Sort, TableScan},
};
use crate::{common::join_type::JoinType, provider::table::TableProvider};
Expand Down Expand Up @@ -92,7 +92,7 @@ impl LogicalPlanBuilder {
})
}

pub fn aggregate(self, group_expr: Vec<LogicalExpr>, aggr_expr: Vec<AggregateExpr>) -> Result<Self> {
pub fn aggregate(self, group_expr: Vec<LogicalExpr>, aggr_expr: Vec<LogicalExpr>) -> Result<Self> {
Aggregate::try_new(self.plan, group_expr, aggr_expr)
.map(|s| LogicalPlanBuilder::from(LogicalPlan::Aggregate(s)))
}
Expand Down
43 changes: 18 additions & 25 deletions qurious/src/logical/expr/aggregate.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use arrow::datatypes::{Field, FieldRef};
use arrow::datatypes::{DataType, Field, FieldRef};

use crate::error::{Error, Result};
use crate::logical::expr::LogicalExpr;
use crate::logical::plan::LogicalPlan;
use std::convert::TryFrom;
use std::fmt::Display;
use std::sync::Arc;
use std::convert::TryFrom;

use super::Column;

Expand Down Expand Up @@ -40,7 +40,10 @@ impl TryFrom<String> for AggregateOperator {
"max" => Ok(AggregateOperator::Max),
"avg" => Ok(AggregateOperator::Avg),
"count" => Ok(AggregateOperator::Count),
_ => Err(Error::InternalError(format!("{} is not a valid aggregate operator", value))),
_ => Err(Error::InternalError(format!(
"{} is not a valid aggregate operator",
value
))),
}
}
}
Expand All @@ -61,50 +64,40 @@ pub struct AggregateExpr {

impl AggregateExpr {
pub fn field(&self, plan: &LogicalPlan) -> Result<FieldRef> {
self.expr.field(plan).map(|field| {
self.expr.field(plan).and_then(|field| {
let col_name = if let LogicalExpr::Column(inner) = self.expr.as_ref() {
&inner.quanlified_name()
} else {
field.name()
};

Arc::new(Field::new(
Ok(Arc::new(Field::new(
format!("{}({})", self.op, col_name),
field.data_type().clone(),
self.infer_type(field.data_type())?,
true,
))
)))
})
}

pub fn as_column(&self) -> Result<LogicalExpr> {
pub(crate) fn as_column(&self) -> Result<LogicalExpr> {
self.expr.as_column().map(|inner_col| {
LogicalExpr::Column(Column {
name: format!("{}({})", self.op, inner_col),
relation: None,
})
})
}

pub(crate) fn infer_type(&self, expr_data_type: &DataType) -> Result<DataType> {
match self.op {
AggregateOperator::Count => Ok(DataType::Int64),
_ => Ok(expr_data_type.clone()),
}
}
}

impl Display for AggregateExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}({})", self.op, self.expr)
}
}

// macro_rules! make_aggregate_expr_fn {
// ($name: ident, $op: expr, $re: ident) => {
// pub fn $name(expr: LogicalExpr) -> $re {
// $re {
// op: $op,
// expr: Box::new(expr),
// }
// }
// };
// }

// make_aggregate_expr_fn!(sum, AggregateOperator::Sum, AggregateExpr);
// make_aggregate_expr_fn!(min, AggregateOperator::Min, AggregateExpr);
// make_aggregate_expr_fn!(max, AggregateOperator::Max, AggregateExpr);
// make_aggregate_expr_fn!(avg, AggregateOperator::Avg, AggregateExpr);
// make_aggregate_expr_fn!(count, AggregateOperator::Count, AggregateExpr);
16 changes: 7 additions & 9 deletions qurious/src/logical/expr/binary.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::datatypes::{DataType, Field, FieldRef, Schema};

use crate::datatypes::operator::Operator;
use crate::error::Result;
Expand Down Expand Up @@ -28,12 +28,12 @@ impl BinaryExpr {
pub fn field(&self, plan: &LogicalPlan) -> Result<FieldRef> {
Ok(Arc::new(Field::new(
format!("({} {} {})", self.left, self.op, self.right),
self.get_result_type(plan)?,
false,
self.get_result_type(&plan.schema())?,
true,
)))
}

pub fn get_result_type(&self, plan: &LogicalPlan) -> Result<DataType> {
pub fn get_result_type(&self, schema: &Arc<Schema>) -> Result<DataType> {
match self.op {
Operator::Eq
| Operator::NotEq
Expand All @@ -46,12 +46,10 @@ impl BinaryExpr {
return Ok(DataType::Boolean);
}
_ => {
let ll = self.left.field(plan)?;
let rr = self.right.field(plan)?;
let left_type = ll.data_type();
let right_type = rr.data_type();
let left_type = self.left.data_type(schema)?;
let right_type = self.right.data_type(schema)?;

Ok(utils::get_input_types(left_type, right_type))
Ok(utils::get_input_types(&left_type, &right_type))
}
}
}
Expand Down
34 changes: 1 addition & 33 deletions qurious/src/logical/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;

use arrow::datatypes::{FieldRef, SchemaRef};
use arrow::datatypes::FieldRef;

use crate::arrow_err;
use crate::common::table_relation::TableRelation;
Expand Down Expand Up @@ -39,38 +39,6 @@ impl Column {
self.name.clone()
}
}

pub fn normalize_col_with_schemas_and_ambiguity_check(
self,
schemas: &[&[(&TableRelation, SchemaRef)]],
) -> Result<Self> {
if self.relation.is_some() {
return Ok(self);
}

for schema_level in schemas {
let mut matched = schema_level
.iter()
.filter_map(|(relation, schema)| schema.field_with_name(&self.name).map(|f| (relation, f)).ok())
.collect::<Vec<_>>();

if matched.len() > 1 {
return Err(Error::InternalError(format!("Column \"{}\" is ambiguous", self.name)));
}

if let Some((relation, _)) = matched.pop() {
return Ok(Self {
name: self.name,
relation: Some((*relation).clone()),
});
}
}

Err(Error::InternalError(format!(
"Column \"{}\" not found in any table",
self.name
)))
}
}

impl Display for Column {
Expand Down
Loading

0 comments on commit 4e445fb

Please sign in to comment.