a large refactoring for normalizing column names
holicc committed Jul 23, 2024
1 parent a73d551 commit 224366a
Showing 17 changed files with 363 additions and 365 deletions.
8 changes: 3 additions & 5 deletions src/common/mod.rs
@@ -1,5 +1,3 @@
mod join_type;
mod table_relation;

pub use join_type::*;
pub use table_relation::{OwnedTableRelation, TableRelation};
pub mod join_type;
pub mod table_relation;
pub mod table_schema;
59 changes: 17 additions & 42 deletions src/common/table_relation.rs
@@ -1,51 +1,32 @@
use std::{borrow::Cow, fmt::Display};

pub type OwnedTableRelation = TableRelation<'static>;
use std::{fmt::Display, sync::Arc};

#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TableRelation<'a> {
pub enum TableRelation {
/// An unqualified table reference, e.g. "table"
Bare {
/// The table name
table: Cow<'a, str>,
table: Arc<str>,
},
/// A partially resolved table reference, e.g. "schema.table"
Partial {
/// The schema containing the table
schema: Cow<'a, str>,
schema: Arc<str>,
/// The table name
table: Cow<'a, str>,
table: Arc<str>,
},
/// A fully resolved table reference, e.g. "catalog.schema.table"
Full {
/// The catalog (aka database) containing the table
catalog: Cow<'a, str>,
catalog: Arc<str>,
/// The schema containing the table
schema: Cow<'a, str>,
schema: Arc<str>,
/// The table name
table: Cow<'a, str>,
table: Arc<str>,
},
}

impl<'a> TableRelation<'a> {
pub fn to_owned(&self) -> OwnedTableRelation {
match self {
TableRelation::Bare { table } => OwnedTableRelation::Bare {
table: table.to_string().into(),
},
TableRelation::Partial { schema, table } => OwnedTableRelation::Partial {
schema: schema.to_string().into(),
table: table.to_string().into(),
},
TableRelation::Full { catalog, schema, table } => OwnedTableRelation::Full {
catalog: catalog.to_string().into(),
schema: schema.to_string().into(),
table: table.to_string().into(),
},
}
}

fn parse_str(a: &'a str) -> Self {
impl TableRelation {
fn parse_str(a: &str) -> Self {
let mut idents = a.split('.').into_iter().collect::<Vec<&str>>();

match idents.len() {
@@ -79,27 +60,21 @@ impl<'a> TableRelation<'a> {
}
}

impl Display for TableRelation<'_> {
impl Display for TableRelation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_quanlify_name().fmt(f)
}
}

impl<'a> From<&'a str> for TableRelation<'a> {
fn from(value: &'a str) -> Self {
TableRelation::parse_str(value)
}
}

impl From<&String> for OwnedTableRelation {
fn from(value: &String) -> Self {
TableRelation::parse_str(&value).to_owned()
impl From<String> for TableRelation {
fn from(value: String) -> Self {
TableRelation::parse_str(&value)
}
}

impl From<String> for OwnedTableRelation {
fn from(value: String) -> Self {
TableRelation::parse_str(&value).to_owned()
impl From<&str> for TableRelation {
fn from(value: &str) -> Self {
TableRelation::parse_str(value)
}
}

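The net effect of this file's changes is that TableRelation is now an owned, lifetime-free type built from Arc<str> parts, so the OwnedTableRelation alias and the to_owned conversion disappear. A minimal standalone sketch of the new behaviour, assuming parse_str (whose match arms are collapsed above) maps one, two, or three dot-separated identifiers to Bare, Partial, and Full:

use std::sync::Arc;

// Hypothetical standalone sketch mirroring the refactored type; the real
// TableRelation lives in src/common/table_relation.rs.
#[derive(Debug, PartialEq)]
enum TableRelation {
    Bare { table: Arc<str> },
    Partial { schema: Arc<str>, table: Arc<str> },
    Full { catalog: Arc<str>, schema: Arc<str>, table: Arc<str> },
}

impl From<&str> for TableRelation {
    fn from(value: &str) -> Self {
        // Assumed rule: 1 identifier -> Bare, 2 -> Partial, 3 -> Full.
        let idents: Vec<&str> = value.split('.').collect();
        match idents.as_slice() {
            [table] => TableRelation::Bare { table: Arc::from(*table) },
            [schema, table] => TableRelation::Partial {
                schema: Arc::from(*schema),
                table: Arc::from(*table),
            },
            [catalog, schema, table] => TableRelation::Full {
                catalog: Arc::from(*catalog),
                schema: Arc::from(*schema),
                table: Arc::from(*table),
            },
            _ => panic!("unsupported identifier count"),
        }
    }
}

fn main() {
    let relation = TableRelation::from("public.users");
    assert_eq!(
        relation,
        TableRelation::Partial {
            schema: Arc::from("public"),
            table: Arc::from("users"),
        }
    );
}

Trading Cow<'a, str> for Arc<str> gives up zero-copy borrowing in exchange for cheap clones and the removal of the lifetime parameter that previously leaked into every consumer of the type.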
11 changes: 11 additions & 0 deletions src/common/table_schema.rs
@@ -0,0 +1,11 @@
use std::sync::Arc;

use crate::common::table_relation::TableRelation;
use arrow::datatypes::Schema;

pub type TableSchemaRef = Arc<TableSchema>;

pub struct TableSchema {
schema: Schema,
relation: TableRelation,
}
2 changes: 1 addition & 1 deletion src/datasource/file/mod.rs
@@ -31,7 +31,7 @@ impl DataFilePath for &str {
}
}

fn parse_path<S: AsRef<str>>(path: S) -> Result<Url> {
pub fn parse_path<S: AsRef<str>>(path: S) -> Result<Url> {
match path.as_ref().parse::<Url>() {
Ok(url) => Ok(url),
Err(url::ParseError::RelativeUrlWithoutBase) => fs::canonicalize(path.as_ref())
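The only change in this hunk is visibility: parse_path becomes pub so the new table-source factory (below, in src/execution/registry/mod.rs) can reuse it. For context, a rough sketch of the parse-or-canonicalize fallback; the tail of the real function is collapsed above, so the file-URL conversion shown here is an assumption and the error type is simplified to String:

use std::fs;
use url::Url;

// Sketch only: absolute URLs parse directly, bare relative paths are
// canonicalized and (assumed) converted into file:// URLs.
pub fn parse_path(path: &str) -> Result<Url, String> {
    match path.parse::<Url>() {
        Ok(url) => Ok(url),
        Err(url::ParseError::RelativeUrlWithoutBase) => {
            let absolute = fs::canonicalize(path).map_err(|e| e.to_string())?;
            Url::from_file_path(&absolute).map_err(|_| format!("invalid path: {}", path))
        }
        Err(e) => Err(e.to_string()),
    }
}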
4 changes: 2 additions & 2 deletions src/datasource/mod.rs
@@ -3,7 +3,7 @@ pub mod memory;

use crate::{datatypes::scalar::ScalarValue, error::Result, logical::expr::LogicalExpr};
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use std::{collections::HashMap, fmt::Debug};
use std::fmt::Debug;

pub trait DataSource: Debug + Sync + Send {
fn schema(&self) -> SchemaRef;
@@ -15,4 +15,4 @@ pub trait DataSource: Debug + Sync + Send {
fn get_column_default(&self, _column: &str) -> Option<&ScalarValue> {
None
}
}
}
58 changes: 48 additions & 10 deletions src/execution/registry/mod.rs
@@ -1,4 +1,6 @@
use crate::datasource::DataSource;
use crate::datasource::file::csv::CsvReadOptions;
use crate::datasource::file::json::JsonReadOptions;
use crate::datasource::{file, DataSource};
use crate::error::{Error, Result};
use std::collections::HashMap;
use std::fmt::Debug;
@@ -12,24 +14,35 @@ pub trait TableRegistry: Debug + Sync + Send {
fn get_table_source(&self, name: &str) -> Result<Arc<dyn DataSource>>;
}

pub trait TableSourceFactory: Debug + Sync + Send {
fn create(&self, name: &str) -> Result<Arc<dyn DataSource>>;
}

#[derive(Debug)]
pub struct HashMapTableRegistry {
pub struct DefaultTableRegistry {
tables: HashMap<String, Arc<dyn DataSource>>,
factory: Arc<dyn TableSourceFactory>,
}

impl Default for HashMapTableRegistry {
impl Default for DefaultTableRegistry {
fn default() -> Self {
Self { tables: HashMap::new() }
Self {
tables: HashMap::new(),
factory: Arc::new(DefaultTableSourceFactory),
}
}
}

impl HashMapTableRegistry {
impl DefaultTableRegistry {
pub fn new(tables: HashMap<String, Arc<dyn DataSource>>) -> Self {
Self { tables }
Self {
tables,
factory: Arc::new(DefaultTableSourceFactory),
}
}
}

impl TableRegistry for HashMapTableRegistry {
impl TableRegistry for DefaultTableRegistry {
fn register_table(&mut self, name: &str, table: Arc<dyn DataSource>) -> Result<()> {
self.tables.insert(name.to_string(), table);
Ok(())
@@ -43,9 +56,34 @@ impl TableRegistry for HashMapTableRegistry {
}

fn get_table_source(&self, name: &str) -> Result<Arc<dyn DataSource>> {
match self.tables.get(name) {
Some(table) => Ok(table.clone()),
None => Err(Error::PlanError(format!("No table named '{}'", name))),
if let Some(table) = self.tables.get(name) {
return Ok(table.clone());
}
// try to create a dynamic table
self.factory.create(name)
}
}

#[derive(Debug)]
pub struct DefaultTableSourceFactory;

impl TableSourceFactory for DefaultTableSourceFactory {
fn create(&self, name: &str) -> Result<Arc<dyn DataSource>> {
let url = file::parse_path(name)
.map_err(|e| Error::PlanError(format!("No table named '{}' found, cause: {}", name, e.to_string())))?;

if url.scheme() != "file" {
return Err(Error::InternalError(format!("Unsupported table source: {}", name)));
}

let path = url.path().to_string();
let ext = path.split('.').last().unwrap_or_default();

match ext {
"csv" => file::csv::read_csv(path, CsvReadOptions::default()),
"json" => file::json::read_json(path, JsonReadOptions::default()),
"parquet" => file::parquet::read_parquet(path),
_ => return Err(Error::InternalError(format!("Unsupported file format: {}", name))),
}
}
}
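The registry now falls back to a TableSourceFactory when a name is not registered, which is what lets a plain file path act as a table name. A self-contained sketch of that lookup-then-fallback shape, with String standing in for Arc<dyn DataSource> and the extension dispatch reduced to a stub:

use std::collections::HashMap;
use std::sync::Arc;

// Sketch of the pattern in DefaultTableRegistry::get_table_source above;
// the real factory resolves the path and builds a csv/json/parquet reader.
trait SourceFactory {
    fn create(&self, name: &str) -> Result<Arc<String>, String>;
}

struct ExtensionFactory;

impl SourceFactory for ExtensionFactory {
    fn create(&self, name: &str) -> Result<Arc<String>, String> {
        match name.split('.').last().unwrap_or_default() {
            "csv" | "json" | "parquet" => Ok(Arc::new(format!("reader for {name}"))),
            _ => Err(format!("Unsupported file format: {name}")),
        }
    }
}

struct Registry {
    tables: HashMap<String, Arc<String>>,
    factory: Box<dyn SourceFactory>,
}

impl Registry {
    fn get(&self, name: &str) -> Result<Arc<String>, String> {
        if let Some(table) = self.tables.get(name) {
            return Ok(table.clone());
        }
        // Miss: defer to the factory, mirroring the dynamic-table path above.
        self.factory.create(name)
    }
}

fn main() {
    let registry = Registry {
        tables: HashMap::new(),
        factory: Box::new(ExtensionFactory),
    };
    assert!(registry.get("data/users.csv").is_ok());
    assert!(registry.get("data/users.xml").is_err());
}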
33 changes: 26 additions & 7 deletions src/execution/session.rs
@@ -1,27 +1,30 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use arrow::array::RecordBatch;
use sqlparser::parser::Parser;

use super::registry::TableRegistry;
use crate::common::table_relation::TableRelation;
use crate::datasource::DataSource;
use crate::error::Error;
use crate::execution::registry::HashMapTableRegistry;
use crate::execution::registry::DefaultTableRegistry;
use crate::logical::plan::LogicalPlan;
use crate::optimizer::Optimzier;
use crate::planner::sql::SqlQueryPlanner;
use crate::planner::QueryPlanner;
use crate::{error::Result, planner::DefaultQueryPlanner};

pub struct ExecuteSession {
tables: Arc<RwLock<dyn TableRegistry>>,
table_registry: Arc<RwLock<dyn TableRegistry>>,
query_planner: Box<dyn QueryPlanner>,
optimizer: Optimzier,
}

impl Default for ExecuteSession {
fn default() -> Self {
Self {
tables: Arc::new(RwLock::new(HashMapTableRegistry::default())),
table_registry: Arc::new(RwLock::new(DefaultTableRegistry::default())),
query_planner: Box::new(DefaultQueryPlanner),
optimizer: Optimzier::new(),
}
@@ -30,7 +33,13 @@ impl Default for ExecuteSession {

impl ExecuteSession {
pub fn sql(&self, sql: &str) -> Result<Vec<RecordBatch>> {
SqlQueryPlanner::create_logical_plan(self.tables.clone(), sql)
// parse sql collect tables
let mut parser = Parser::new(sql);
let stmt = parser.parse().map_err(|e| Error::SQLParseError(e))?;
// register tables for statement if there are any file source tables to be registered
let relations = self.resolve_tables(parser.relations)?;
// create logical plan
SqlQueryPlanner::create_logical_plan(stmt, relations)
.and_then(|logical_plan| self.execute_logical_plan(&logical_plan))
}

@@ -40,14 +49,24 @@ impl ExecuteSession {
}

pub fn register_table(&mut self, name: &str, table: Arc<dyn DataSource>) -> Result<()> {
self.tables
self.table_registry
.write()
.map_err(|e| Error::InternalError(e.to_string()))?
.register_table(name, table)
}

pub(crate) fn get_tables(&self) -> Arc<RwLock<dyn TableRegistry>> {
self.tables.clone()
pub(crate) fn resolve_tables(&self, tables: Vec<String>) -> Result<HashMap<TableRelation, Arc<dyn DataSource>>> {
tables
.into_iter()
.map(|r| {
let relation: TableRelation = r.into();
self.table_registry
.write()
.map_err(|e| Error::InternalError(e.to_string()))?
.get_table_source(&relation.to_quanlify_name())
.map(|table| (relation, table))
})
.collect::<Result<_>>()
}
}

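resolve_tables maps every relation the parser collected to a data source and short-circuits on the first failure by collecting into a Result of a HashMap. A standalone sketch of that collection pattern, with u32 standing in for Arc<dyn DataSource> and a plain HashMap standing in for the registry:

use std::collections::HashMap;

// Sketch of the resolve_tables shape: Result<HashMap<_, _>, _> implements
// FromIterator over Result items, so the first lookup error aborts the loop.
fn resolve(names: Vec<String>, known: &HashMap<String, u32>) -> Result<HashMap<String, u32>, String> {
    names
        .into_iter()
        .map(|name| {
            known
                .get(&name)
                .copied()
                .map(|source| (name.clone(), source))
                .ok_or_else(|| format!("No table named '{name}'"))
        })
        .collect()
}

fn main() {
    let known = HashMap::from([("users".to_string(), 1), ("orders".to_string(), 2)]);
    assert!(resolve(vec!["users".into(), "orders".into()], &known).is_ok());
    assert!(resolve(vec!["users".into(), "missing".into()], &known).is_err());
}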
31 changes: 22 additions & 9 deletions src/logical/builder.rs
@@ -5,8 +5,10 @@ use super::{
expr::{AggregateExpr, LogicalExpr, SortExpr},
plan::{Aggregate, CrossJoin, EmptyRelation, Join, Limit, LogicalPlan, Projection, Sort, TableScan},
};
use crate::{common::JoinType, error::Result};
use crate::{common::OwnedTableRelation, datasource::DataSource};
use crate::{
common::join_type::JoinType, datasource::DataSource, planner::normalize_col_with_schemas_and_ambiguity_check,
};
use crate::{common::table_relation::TableRelation, error::Result};

pub struct LogicalPlanBuilder {
plan: LogicalPlan,
@@ -37,7 +39,7 @@ impl LogicalPlanBuilder {
}

pub fn scan(
relation: impl Into<OwnedTableRelation>,
relation: impl Into<TableRelation>,
table_source: Arc<dyn DataSource>,
filter: Option<LogicalExpr>,
) -> Result<Self> {
@@ -93,12 +95,23 @@ impl LogicalPlanBuilder {
}

pub fn sort(self, order_by: Vec<SortExpr>) -> Result<Self> {
Ok(LogicalPlanBuilder {
plan: LogicalPlan::Sort(Sort {
exprs: order_by,
input: Box::new(self.plan),
}),
})
order_by
.into_iter()
.map(|sort| {
normalize_col_with_schemas_and_ambiguity_check(*sort.expr, &[&self.plan.relation()]).map(|expr| {
SortExpr {
expr: expr.into(),
asc: sort.asc,
}
})
})
.collect::<Result<_>>()
.map(|sort_exprs| LogicalPlanBuilder {
plan: LogicalPlan::Sort(Sort {
exprs: sort_exprs,
input: Box::new(self.plan),
}),
})
}

pub fn limit(self, limit: i64, offset: i64) -> Self {
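The sort builder above now normalizes each ORDER BY expression against the plan's relation before constructing the Sort node, instead of passing expressions through untouched. A standalone sketch of the idea behind normalize_col_with_schemas_and_ambiguity_check, reduced to plain strings; the real helper in crate::planner works on LogicalExpr values and schemas:

// Sketch only: an unqualified column is resolved against the known relations,
// and a name that matches more than one relation is rejected as ambiguous.
fn normalize(column: &str, relations: &[(&str, &[&str])]) -> Result<String, String> {
    let matches: Vec<&str> = relations
        .iter()
        .filter(|(_, cols)| cols.contains(&column))
        .map(|(rel, _)| *rel)
        .collect();
    match matches.as_slice() {
        [rel] => Ok(format!("{rel}.{column}")),
        [] => Err(format!("No field named '{column}'")),
        _ => Err(format!("Column '{column}' is ambiguous")),
    }
}

fn main() {
    let rels: &[(&str, &[&str])] = &[("users", &["id", "name"]), ("orders", &["id", "total"])];
    assert_eq!(normalize("name", rels).unwrap(), "users.name");
    assert!(normalize("id", rels).is_err()); // ambiguous: both relations have an id column
}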