Refactor codebase to add datetime functions module
holicc committed Sep 19, 2024
1 parent b7f73bc commit 98ab2f2
Showing 40 changed files with 1,533 additions and 313 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -16,8 +16,8 @@ edition = "2021"

[workspace.dependencies]
sqlparser = { path = "sqlparser" }
parquet = "52.0.0"
arrow = "52.0.0"
parquet = "53.0.0"
arrow = "53.0.0"
url = "2.5.0"
log = "^0.4"
dashmap = "6.0.1"
23 changes: 19 additions & 4 deletions qurious/Cargo.toml
@@ -10,18 +10,33 @@ parquet = { workspace = true }
arrow = { workspace = true }
url = { workspace = true }
dashmap = { workspace = true }
-connectorx = { workspace = true, features = ["src_postgres", "dst_arrow"] }
log = { workspace = true }
-postgres = "0.19.8"

itertools = "0.13.0"
rayon = "1.10.0"

+connectorx = { optional = true, workspace = true, features = [
+"src_postgres",
+"dst_arrow",
+] }
+postgres = { version = "0.19.8", optional = true }
+rayon = { version = "1.10.0", optional = true }


+[features]
+connectorx = [
+"connectorx/src_postgres",
+"connectorx/dst_arrow",
+"postgres",
+"rayon",
+]

[dev-dependencies]
-arrow = { version = "52.0.0", features = ["prettyprint", "test_utils"] }
+arrow = { workspace = true, features = ["prettyprint", "test_utils"] }
async-trait = "0.1.81"
env_logger = "0.11.5"
sqllogictest = "0.21.0"
+rayon = { version = "1.10.0" }


[[test]]
harness = false
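With ConnectorX-backed Postgres support now behind an off-by-default feature, downstream crates must opt in explicitly. A minimal sketch of the consuming side, assuming a published `qurious` crate (the version number here is illustrative, not taken from the commit):

```toml
# Hypothetical downstream Cargo.toml. Enabling the `connectorx` feature
# activates the optional connectorx, postgres, and rayon dependencies
# declared above; without it none of them are compiled.
[dependencies]
qurious = { version = "0.1", features = ["connectorx"] }
```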
32 changes: 32 additions & 0 deletions qurious/src/datasource/memory.rs
@@ -2,6 +2,8 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;

+use arrow::array::AsArray;
+use arrow::compute::filter_record_batch;
use arrow::datatypes::Schema;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
@@ -11,6 +13,7 @@ use crate::datatypes::scalar::ScalarValue;
use crate::error::Error;
use crate::error::Result;
use crate::logical::expr::LogicalExpr;
+use crate::physical::expr::PhysicalExpr;
use crate::physical::plan::PhysicalPlan;
use crate::provider::table::TableProvider;

@@ -83,4 +86,33 @@ impl TableProvider for MemoryTable {

Ok(input_batch.iter().map(|batch| batch.num_rows()).sum::<usize>() as u64)
}

+fn delete(&self, filter: Option<Arc<dyn PhysicalExpr>>) -> Result<u64> {
+let mut data = self
+.data
+.write()
+.map_err(|e| Error::InternalError(format!("delete error: {}", e)))?;
+
+if let Some(predicate) = filter {
+let new_batch = data
+.iter()
+.map(|batch| {
+let mask = predicate.evaluate(batch)?;
+let mask = arrow::compute::not(mask.as_boolean())?;
+let filtered_batch = filter_record_batch(batch, &mask)?;
+Ok(filtered_batch)
+})
+.collect::<Result<Vec<RecordBatch>>>()?;
+
+data.clear();
+data.extend(new_batch);
+
+Ok(data.iter().map(|batch| batch.num_rows()).sum::<usize>() as u64)
+} else {
+let rows_affected = data.iter().map(|batch| batch.num_rows()).sum::<usize>() as u64;
+data.clear();
+
+Ok(rows_affected)
+}
+}
}
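The new `delete` path evaluates the predicate against each batch, negates the resulting boolean mask with `arrow::compute::not`, and keeps only the non-matching rows via `filter_record_batch`. A self-contained sketch of that kernel combination, with a made-up column and a pre-evaluated mask standing in for the predicate:

```rust
use std::sync::Arc;

use arrow::array::{BooleanArray, Int32Array, RecordBatch};
use arrow::compute::{filter_record_batch, not};
use arrow::datatypes::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("v1", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4]))])?;

    // Pretend the DELETE predicate `v1 > 2` already evaluated to this mask.
    let mask = BooleanArray::from(vec![false, false, true, true]);

    // Deleting is keeping the rows where the predicate is false.
    let remaining = filter_record_batch(&batch, &not(&mask)?)?;
    assert_eq!(remaining.num_rows(), 2);
    Ok(())
}
```

Two subtleties worth noting: `not` maps a null mask entry to null, and `filter_record_batch` drops rows whose mask entry is null, so a row whose predicate evaluates to NULL is deleted here, whereas SQL's `DELETE ... WHERE` would keep it. Also, the filtered branch above returns the number of rows remaining rather than the number deleted, while the truncate branch returns the number deleted.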
1 change: 1 addition & 0 deletions qurious/src/datasource/mod.rs
@@ -1,5 +1,6 @@
pub mod file;
pub mod memory;
+#[cfg(feature = "connectorx")]
pub mod connectorx;


2 changes: 1 addition & 1 deletion qurious/src/datatypes/scalar.rs
@@ -235,7 +235,7 @@ impl Display for ScalarValue {
ScalarValue::Float64(None) => write!(f, "null"),
ScalarValue::Float32(Some(v)) => write!(f, "Float32({})", v),
ScalarValue::Float32(None) => write!(f, "null"),
ScalarValue::Utf8(Some(v)) => write!(f, "Utf8({})", v),
ScalarValue::Utf8(Some(v)) => write!(f, "Utf8('{}')", v),
ScalarValue::Utf8(None) => write!(f, "null"),
}
}
2 changes: 2 additions & 0 deletions qurious/src/error.rs
@@ -43,6 +43,7 @@ pub enum Error {
InternalError(String),
ColumnNotFound(String),
DuplicateColumn(String),
+InvalidArgumentError(String),
CompareError(String),
ComputeError(String),
ArrowError(ArrowError, Option<String>),
@@ -88,6 +89,7 @@ impl Display for Error {
Error::PlanError(e) => write!(f, "Plan Error: {}", e),
Error::DuplicateColumn(c) => write!(f, "Duplicate column: {}", c),
Error::TableNotFound(e) => write!(f, "Table Not Found: {}", e),
+Error::InvalidArgumentError(e) => write!(f, "Invalid Argument Error: {}", e),
}
}
}
4 changes: 0 additions & 4 deletions qurious/src/execution/providers.rs
@@ -28,10 +28,6 @@ impl CatalogProviderList {
Ok(self.catalogs.insert(name.to_owned(), catalog))
}

-pub fn deregister_catalog(&self, name: &str) -> Result<Option<Arc<dyn CatalogProvider>>> {
-Ok(self.catalogs.remove(name).map(|(_, v)| v))
-}

pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
self.catalogs.get(name).map(|v| v.value().clone())
}
86 changes: 60 additions & 26 deletions qurious/src/execution/session.rs
@@ -1,5 +1,5 @@
use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
use std::vec;

use arrow::array::RecordBatch;
@@ -8,7 +8,11 @@ use sqlparser::parser::{Parser, TableInfo};
use crate::common::table_relation::TableRelation;
use crate::datasource::memory::MemoryTable;
use crate::error::Error;
-use crate::logical::plan::{CreateMemoryTable, DdlStatement, DmlOperator, DmlStatement, DropTable, LogicalPlan};
+use crate::functions::{all_builtin_functions, UserDefinedFunction};
+use crate::internal_err;
+use crate::logical::plan::{
+CreateMemoryTable, DdlStatement, DmlOperator, DmlStatement, DropTable, Filter, LogicalPlan,
+};
use crate::optimizer::Optimzier;
use crate::planner::sql::{parse_csv_options, parse_file_path, SqlQueryPlanner};
use crate::planner::QueryPlanner;
@@ -29,6 +33,7 @@ pub struct ExecuteSession {
table_factory: DefaultTableFactory,
catalog_list: CatalogProviderList,
optimizer: Optimzier,
+udfs: RwLock<HashMap<String, Arc<dyn UserDefinedFunction>>>,
}

impl ExecuteSession {
@@ -39,6 +44,12 @@
pub fn new_with_config(config: SessionConfig) -> Result<Self> {
let catalog_list = CatalogProviderList::default();
let catalog: Arc<dyn CatalogProvider> = Arc::new(MemoryCatalogProvider::default());
+let udfs = RwLock::new(
+all_builtin_functions()
+.into_iter()
+.map(|udf| (udf.name().to_uppercase().to_string(), udf))
+.collect(),
+);

catalog.register_schema(&config.default_schema, Arc::new(MemorySchemaProvider::default()))?;
catalog_list.register_catalog(&config.default_catalog, catalog)?;
@@ -49,6 +60,7 @@
catalog_list,
table_factory: DefaultTableFactory::new(),
optimizer: Optimzier::new(),
+udfs,
})
}

@@ -58,8 +70,12 @@
let stmt = parser.parse().map_err(|e| Error::SQLParseError(e))?;
// register tables for statement if there are any file source tables to be registered
let relations = self.resolve_tables(parser.tables)?;
+let udfs = &self
+.udfs
+.read()
+.map_err(|e| Error::InternalError(format!("failed to get udfs: {}", e)))?;
// create logical plan
-SqlQueryPlanner::create_logical_plan(stmt, relations)
+SqlQueryPlanner::create_logical_plan(stmt, relations, udfs)
.and_then(|logical_plan| self.execute_logical_plan(&logical_plan))
}

@@ -68,20 +84,8 @@

match &plan {
LogicalPlan::Ddl(ddl) => self.execute_ddl(ddl),
-LogicalPlan::Dml(DmlStatement {
-relation, op, input, ..
-}) => {
-let source = self.find_table_provider(relation)?;
-let input = self.planner.create_physical_plan(input)?;
-
-let rows_affected = match op {
-DmlOperator::Insert => source.insert(input)?,
-_ => todo!(),
-};
-
-Ok(vec![make_count_batch(rows_affected)])
-}
-plan => self.planner.create_physical_plan(&plan)?.execute(),
+LogicalPlan::Dml(stmt) => self.execute_dml(stmt),
+plan => self.planner.create_physical_plan(plan)?.execute(),
}
}

@@ -96,9 +100,41 @@
pub fn register_catalog(&self, name: &str, catalog_provider: Arc<dyn CatalogProvider>) -> Result<()> {
self.catalog_list.register_catalog(name, catalog_provider).map(|_| ())
}
+pub fn register_udf(&self, name: &str, udf: Arc<dyn UserDefinedFunction>) -> Result<()> {
+let mut udfs = self
+.udfs
+.write()
+.map_err(|e| Error::InternalError(format!("failed to register udf: {}", e)))?;
+udfs.insert(name.to_string(), udf);
+Ok(())
+}
}

impl ExecuteSession {
+fn execute_dml(&self, stmt: &DmlStatement) -> Result<Vec<RecordBatch>> {
+let source = self.find_table_provider(&stmt.relation)?;
+let rows_affected = match stmt.op {
+DmlOperator::Insert => self.execute_insert(source, &stmt.input),
+DmlOperator::Delete => self.execute_delete(source, &stmt.input),
+_ => internal_err!("Unsupported DML {} operation", stmt.op),
+}?;
+
+Ok(vec![make_count_batch(rows_affected)])
+}
+
+fn execute_delete(&self, source: Arc<dyn TableProvider>, input: &LogicalPlan) -> Result<u64> {
+let predicate = if let LogicalPlan::Filter(Filter { input, expr }) = input {
+Some(self.planner.create_physical_expr(&input.schema(), expr)?)
+} else {
+None
+};
+source.delete(predicate)
+}
+
+fn execute_insert(&self, source: Arc<dyn TableProvider>, input: &LogicalPlan) -> Result<u64> {
+let physical_plan = self.planner.create_physical_plan(input)?;
+source.insert(physical_plan)
+}
/// Resolve tables from the table registry
/// If the table is not found in the registry, an error is returned
/// Inspired by the DataFusion implementation, but simplified: we separate tables into a normal database hierarchy, e.g. catalog.schema.table
@@ -198,11 +234,7 @@

#[cfg(test)]
mod tests {
-use crate::{
-build_schema,
-datasource::{connectorx::postgres::PostgresCatalogProvider, memory::MemoryTable},
-test_utils::assert_batch_eq,
-};
+use crate::{build_schema, datasource::memory::MemoryTable, test_utils::assert_batch_eq};
use arrow::{
array::{Int32Array, StringArray},
util::pretty::print_batches,
@@ -221,10 +253,10 @@
#[test]
fn test_create_table() -> Result<()> {
let session = ExecuteSession::new()?;
session.sql("create table t(v1 int not null, v2 int not null, v3 int not null)")?;
session.sql("insert into t values(1,4,2), (2,3,3), (3,4,4), (4,3,5)")?;
session.sql("create table t(v1 int, v2 int);")?;
session.sql("insert into t values (1, 1), (null, 2), (null, 3), (4, 4);")?;

let batch = session.sql("select count(v3) = min(v3),count(v3),min(v3) from t group by v2")?;
let batch = session.sql("select v2 from t where v1 is null")?;

print_batches(&batch)?;

@@ -301,9 +333,11 @@
}

#[test]
#[cfg(feature = "connectorx")]
fn test_postgres() {
let session = ExecuteSession::new().unwrap();
use crate::datasource::connectorx::postgres::PostgresCatalogProvider;

let session = ExecuteSession::new().unwrap();
let catalog = PostgresCatalogProvider::try_new("postgresql://root:root@localhost:5433/qurious").unwrap();

session.register_catalog("qurious", Arc::new(catalog)).unwrap();
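Sessions now seed their function registry from `all_builtin_functions()` (keyed by upper-cased name) and expose `register_udf` for custom scalar functions. A sketch of a custom function against the trait surface visible in this diff — `name`, `return_type`, and `eval` — with import paths assumed rather than confirmed:

```rust
use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::datatypes::DataType;

// Paths assumed from the diff; the crate's public re-exports may differ.
use qurious::error::{Error, Result};
use qurious::functions::UserDefinedFunction;

#[derive(Debug)]
struct ToInt64;

impl UserDefinedFunction for ToInt64 {
    fn name(&self) -> &str {
        "TO_INT64"
    }

    fn return_type(&self) -> DataType {
        DataType::Int64
    }

    fn eval(&self, args: Vec<ArrayRef>) -> Result<ArrayRef> {
        if args.len() != 1 {
            return Err(Error::InvalidArgumentError("TO_INT64 takes 1 argument".to_string()));
        }
        // Cast the single argument to Int64, much as the new EXTRACT
        // function below post-processes its `date_part` output.
        cast(args[0].as_ref(), &DataType::Int64).map_err(|e| Error::ArrowError(e, None))
    }
}
```

Registration would then be `session.register_udf("TO_INT64", Arc::new(ToInt64))`. Note that the builtins are stored under upper-cased names while `register_udf` inserts the name exactly as given, so passing an already upper-cased name looks like the safe choice.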
62 changes: 62 additions & 0 deletions qurious/src/functions/datetime/extract.rs
@@ -0,0 +1,62 @@
use std::str::FromStr;

use arrow::array::Array;
use arrow::compute::kernels::cast_utils::IntervalUnit;
use arrow::{
array::{ArrayRef, AsArray},
compute::{cast, date_part, DatePart},
datatypes::DataType,
};

use crate::{arrow_err, internal_err};
use crate::{
error::{Error, Result},
functions::UserDefinedFunction,
};

#[derive(Debug)]
pub struct DatetimeExtract;

impl UserDefinedFunction for DatetimeExtract {
fn name(&self) -> &str {
"EXTRACT"
}

fn return_type(&self) -> DataType {
DataType::Int64
}

fn eval(&self, args: Vec<ArrayRef>) -> Result<ArrayRef> {
if args.len() != 2 {
return Err(Error::InvalidArgumentError("EXTRACT requires 2 arguments".to_string()));
}

// get interval_unit value
let interval_unit = if let Some(val) = args.get(0) {
val.as_string::<i32>().value(0)
} else {
return Err(Error::InvalidArgumentError(
"First argument of `DATE_PART` must be non-null scalar Utf8".to_string(),
));
};

match IntervalUnit::from_str(interval_unit)? {
IntervalUnit::Year => date_part_f64(args[1].as_ref(), DatePart::Year),
IntervalUnit::Month => date_part_f64(args[1].as_ref(), DatePart::Month),
IntervalUnit::Week => date_part_f64(args[1].as_ref(), DatePart::Week),
IntervalUnit::Day => date_part_f64(args[1].as_ref(), DatePart::Day),
IntervalUnit::Hour => date_part_f64(args[1].as_ref(), DatePart::Hour),
IntervalUnit::Minute => date_part_f64(args[1].as_ref(), DatePart::Minute),
IntervalUnit::Second => date_part_f64(args[1].as_ref(), DatePart::Second),
IntervalUnit::Millisecond => date_part_f64(args[1].as_ref(), DatePart::Millisecond),
IntervalUnit::Microsecond => date_part_f64(args[1].as_ref(), DatePart::Microsecond),
IntervalUnit::Nanosecond => date_part_f64(args[1].as_ref(), DatePart::Nanosecond),
// century and decade are not supported by `DatePart`, although they are supported in postgres
_ => internal_err!("Date part '{}' not supported", interval_unit),
}
}
}

fn date_part_f64(array: &dyn Array, part: DatePart) -> Result<ArrayRef> {
cast(date_part(array, part)?.as_ref(), &DataType::Int64).map_err(|e| arrow_err!(e))
}
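End to end, `eval` expects two arrays: the date-part keyword as a scalar Utf8 value (argument 0) and the temporal column itself (argument 1). A hedged sketch of driving the function directly, with public import paths assumed; arrow's `IntervalUnit::from_str` lower-cases its input, so either "year" or "YEAR" should parse:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Date32Array, StringArray};

// Assumed public paths mirroring the crate-internal module layout.
use qurious::functions::datetime::extract::DatetimeExtract;
use qurious::functions::UserDefinedFunction;

fn main() -> qurious::error::Result<()> {
    // Arg 0: the date-part keyword as a scalar Utf8 array.
    let unit: ArrayRef = Arc::new(StringArray::from(vec!["year"]));
    // Arg 1: Date32 values, days since the Unix epoch (19_000 = 2022-01-08).
    let dates: ArrayRef = Arc::new(Date32Array::from(vec![19_000]));

    let years = DatetimeExtract.eval(vec![unit, dates])?;
    println!("{years:?}"); // Int64 array containing 2022
    Ok(())
}
```

Despite its name, `date_part_f64` returns `Int64`, not `Float64`: `date_part` yields `Int32` and the helper casts it up to match the declared `return_type`.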
1 change: 1 addition & 0 deletions qurious/src/functions/datetime/mod.rs
@@ -0,0 +1 @@
pub mod extract;
