Skip to content

Commit

Permalink
refactor: Add aggregate sum module and plan for no grouping aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
holicc committed Aug 27, 2024
1 parent 2956a62 commit a12dd81
Show file tree
Hide file tree
Showing 27 changed files with 359 additions and 108 deletions.
5 changes: 3 additions & 2 deletions qurious/src/datasource/file/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use arrow::csv::reader::Format;
use arrow::csv::ReaderBuilder;

use crate::arrow_err;
use crate::datasource::memory::MemoryTable;
use crate::error::{Error, Result};
use crate::provider::table::TableProvider;
Expand Down Expand Up @@ -51,7 +52,7 @@ pub fn read_csv<T: DataFilePath>(path: T, options: CsvReadOptions) -> Result<Arc
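// Schema inference: the first line is the header and the data lines after it
// are used to infer each column's data type.
// NOTE(review): `None` is passed as `max_records` below, so inference is not capped
// at the first 2 records as this comment originally described — confirm the intent.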
let (schema, _) = format.infer_schema(&mut file, None).map_err(|e| Error::ArrowError(e))?;
let (schema, _) = format.infer_schema(&mut file, None).map_err(|e| arrow_err!(e))?;

// rewind the file to the beginning because schema inference has advanced the read position
file.rewind().unwrap();
Expand All @@ -62,7 +63,7 @@ pub fn read_csv<T: DataFilePath>(path: T, options: CsvReadOptions) -> Result<Arc
.with_format(format)
.build(file)
.and_then(|reader| reader.into_iter().collect())
.map_err(|e| Error::ArrowError(e))
.map_err(|e| arrow_err!(e))
.and_then(|data| MemoryTable::try_new(schema, data).map(|v| Arc::new(v) as Arc<dyn TableProvider>))
}
_ => unimplemented!(),
Expand Down
3 changes: 2 additions & 1 deletion qurious/src/datasource/file/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use arrow::json::reader::infer_json_schema_from_seekable;
use arrow::json::ReaderBuilder;

use crate::arrow_err;
use crate::datasource::file::DataFilePath;
use crate::datasource::memory::MemoryTable;
use crate::error::{Error, Result};
Expand All @@ -23,7 +24,7 @@ pub fn read_json<T: DataFilePath>(path: T) -> Result<Arc<dyn TableProvider>> {
ReaderBuilder::new(schema.clone())
.build(reader)
.and_then(|builder| builder.into_iter().collect())
.map_err(|e| Error::ArrowError(e))
.map_err(|e| arrow_err!(e))
.and_then(|data| MemoryTable::try_new(schema, data).map(|v| Arc::new(v) as Arc<dyn TableProvider>))
}

Expand Down
5 changes: 3 additions & 2 deletions qurious/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use arrow::datatypes::Schema;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;

use crate::arrow_err;
use crate::datatypes::scalar::ScalarValue;
use crate::error::Error;
use crate::error::Result;
Expand Down Expand Up @@ -58,12 +59,12 @@ impl TableProvider for MemoryTable {
if let Some(projection) = projection {
let indices = projection
.iter()
.map(|name| self.schema.index_of(name).map_err(|e| Error::ArrowError(e)))
.map(|name| self.schema.index_of(name).map_err(|e| arrow_err!(e)))
.collect::<Result<Vec<_>>>()?;

batches
.iter()
.map(|batch| batch.project(&indices).map_err(|e| Error::ArrowError(e)))
.map(|batch| batch.project(&indices).map_err(|e| arrow_err!(e)))
.collect()
} else {
Ok(batches.clone())
Expand Down
42 changes: 32 additions & 10 deletions qurious/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::fmt::Display;
use std::{
backtrace::{Backtrace, BacktraceStatus},
fmt::Display,
};

use arrow::error::ArrowError;
use parquet::errors::ParquetError;
Expand All @@ -13,6 +16,19 @@ macro_rules! impl_from_error {
};
}

/// Wraps an Arrow error value into `Error::ArrowError`, attaching the text
/// returned by `Error::get_back_trace()` evaluated at the expansion site.
///
/// Both paths are spelled through `$crate`, so a call site does not need to have
/// `Error` imported for the expansion to resolve.
#[macro_export]
macro_rules! arrow_err {
    ($ERR:expr) => {
        $crate::error::Error::ArrowError($ERR, Some($crate::error::Error::get_back_trace()))
    };
}

impl_from_error!(std::io::Error);
impl_from_error!(ParquetError);
impl_from_error!(std::num::ParseIntError);
impl_from_error!(std::num::ParseFloatError);
impl_from_error!(std::str::ParseBoolError);

pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug)]
Expand All @@ -22,34 +38,40 @@ pub enum Error {
DuplicateColumn(String),
CompareError(String),
ComputeError(String),
ArrowError(ArrowError),
ArrowError(ArrowError, Option<String>),
SQLParseError(sqlparser::error::Error),
PlanError(String),
TableNotFound(String),
}

impl Error {
    /// Produces the backtrace text that gets appended to an error message.
    ///
    /// Returns `"\n\nbacktrace: "` followed by the rendered backtrace when
    /// `Backtrace::capture()` reports a `Captured` status; in every other case
    /// an empty string is returned.
    #[inline(always)]
    pub fn get_back_trace() -> String {
        let trace = Backtrace::capture();
        match trace.status() {
            BacktraceStatus::Captured => format!("\n\nbacktrace: {}", trace),
            _ => String::new(),
        }
    }
}

impl std::error::Error for Error {}

impl From<ArrowError> for Error {
fn from(e: ArrowError) -> Self {
Error::ArrowError(e)
Error::ArrowError(e, None)
}
}

impl_from_error!(std::io::Error);
impl_from_error!(ParquetError);
impl_from_error!(std::num::ParseIntError);
impl_from_error!(std::num::ParseFloatError);
impl_from_error!(std::str::ParseBoolError);

impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::InternalError(e) => write!(f, "Internal Error: {}", e),
Error::ColumnNotFound(e) => write!(f, "Column Not Found: {}", e),
Error::CompareError(e) => write!(f, "Compare Error: {}", e),
Error::ComputeError(e) => write!(f, "Compute Error: {}", e),
Error::ArrowError(e) => write!(f, "Arrow Error: {}", e),
Error::ArrowError(e, msg) => write!(f, "Arrow Error: {}, msg: {}", e, msg.clone().unwrap_or_default()),
Error::SQLParseError(e) => write!(f, "SQL Parse Error: {}", e),
Error::PlanError(e) => write!(f, "Plan Error: {}", e),
Error::DuplicateColumn(c) => write!(f, "Duplicate column: {}", c),
Expand Down
28 changes: 13 additions & 15 deletions qurious/src/execution/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ impl ExecuteSession {
fn execute_ddl(&self, ddl: &DdlStatement) -> Result<Vec<RecordBatch>> {
match ddl {
DdlStatement::CreateMemoryTable(CreateMemoryTable { schema, name, input }) => {
let table: TableRelation = name.as_str().into();
let table: TableRelation = name.to_ascii_lowercase().into();
let schema_provider = self.find_schema_provider(&table)?;
let batch = self.execute_logical_plan(input)?;

Expand All @@ -174,15 +174,15 @@ impl ExecuteSession {
.map(|_| vec![])
}
DdlStatement::DropTable(DropTable { name, if_exists }) => {
let table: TableRelation = name.as_str().into();
let table: TableRelation = name.to_ascii_lowercase().into();
let schema_provider = self.find_schema_provider(&table)?;
let provider = schema_provider.deregister_table(table.table())?;

if provider.is_some() || *if_exists {
Ok(vec![])
} else {
Err(Error::PlanError(format!(
"Table not found: {}",
"Drop table failed, table not found: {}",
table.to_quanlify_name()
)))
}
Expand All @@ -198,7 +198,10 @@ mod tests {
datasource::{connectorx::postgres::PostgresCatalogProvider, memory::MemoryTable},
test_utils::assert_batch_eq,
};
use arrow::array::{Int32Array, StringArray};
use arrow::{
array::{Int32Array, StringArray},
util::pretty::print_batches,
};

use super::*;

Expand All @@ -213,19 +216,14 @@ mod tests {
#[test]
fn test_create_table() -> Result<()> {
let session = ExecuteSession::new()?;
let sql = r#"create table t(v1 int)"#;
let sql = r#"create table t(v1 int not null, v2 int not null, v3 int not null)"#;

session.sql(sql)?;
session.sql("insert into T values (1)")?;

let batch = session.sql("select * from T")?;
assert_batch_eq(&batch, vec![
"+----+",
"| v1 |",
"+----+",
"| 1 |",
"+----+",
]);
session.sql("insert into t values(1,4,2), (2,3,3), (3,4,4), (4,3,5)")?;

let batch = session.sql("select sum(v1), sum(v2) from t")?;

print_batches(&batch)?;

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion qurious/src/logical/expr/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl BinaryExpr {

pub fn field(&self, plan: &LogicalPlan) -> Result<FieldRef> {
Ok(Arc::new(Field::new(
self.op.to_string(),
format!("{} {} {}", self.left, self.op, self.right),
match self.op {
Operator::Eq
| Operator::NotEq
Expand Down
3 changes: 2 additions & 1 deletion qurious/src/logical/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use arrow::datatypes::{FieldRef, SchemaRef};

use crate::arrow_err;
use crate::common::table_relation::TableRelation;
use crate::error::{Error, Result};
use crate::logical::plan::LogicalPlan;
Expand All @@ -28,7 +29,7 @@ impl Column {
plan.schema()
.field_with_name(&self.name)
.map(|f| Arc::new(f.clone()))
.map_err(|e| Error::ArrowError(e))
.map_err(|e| arrow_err!(e))
}

pub fn quanlified_name(&self) -> String {
Expand Down
2 changes: 2 additions & 0 deletions qurious/src/logical/plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ impl Projection {
LogicalExpr::Column(i) => Some(i.field(&input)),
LogicalExpr::Literal(i) => Some(Ok(Arc::new(i.to_field()))),
LogicalExpr::Alias(i) => Some(i.expr.field(&input)),
LogicalExpr::AggregateExpr(i) => Some(i.field(&input)),
LogicalExpr::BinaryExpr(i) => Some(i.field(&input)),
a => todo!("Projection::try_new: {:?}", a),
})
.collect::<Result<Vec<FieldRef>>>()
Expand Down
3 changes: 2 additions & 1 deletion qurious/src/logical/plan/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{fmt::Display, sync::Arc};

use arrow::datatypes::{Schema, SchemaRef};

use crate::arrow_err;
use crate::common::table_relation::TableRelation;
use crate::error::{Error, Result};
use crate::logical::expr::LogicalExpr;
Expand Down Expand Up @@ -35,7 +36,7 @@ impl TableScan {
source
.schema()
.field_with_name(name)
.map_err(|err| Error::ArrowError(err))
.map_err(|err| arrow_err!(err))
.cloned()
})
.collect::<Result<Vec<_>>>()
Expand Down
1 change: 1 addition & 0 deletions qurious/src/physical/expr/aggregate/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod max;
pub mod sum;

use arrow::array::ArrayRef;

Expand Down
91 changes: 91 additions & 0 deletions qurious/src/physical/expr/aggregate/sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::fmt::Debug;
use std::{fmt::Display, sync::Arc};

use arrow::array::{ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray};
use arrow::compute;
use arrow::datatypes::{ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type};

use super::{Accumulator, AggregateExpr};
use crate::error::{Error, Result};
use crate::{datatypes::scalar::ScalarValue, physical::expr::PhysicalExpr};

/// Physical `SUM` aggregate expression.
///
/// Pairs the input expression with the data type that decides which
/// accumulator `create_accumulator` builds.
#[derive(Debug)]
pub struct SumAggregateExpr {
/// Input expression; this is what `expression()` hands back.
pub expr: Arc<dyn PhysicalExpr>,
/// Data type matched on in `create_accumulator` to pick the accumulator's numeric type.
pub return_type: DataType,
}

impl SumAggregateExpr {
pub fn new(expr: Arc<dyn PhysicalExpr>, return_type: DataType) -> Self {
Self { expr, return_type }
}
}

impl Display for SumAggregateExpr {
    /// Renders the aggregate as `SUM(<expr>)`, using the inner expression's `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("SUM({})", self.expr))
    }
}

impl AggregateExpr for SumAggregateExpr {
    /// Returns the expression whose values are summed.
    fn expression(&self) -> &Arc<dyn PhysicalExpr> {
        &self.expr
    }

    /// Creates a fresh, empty accumulator specialised for `return_type`.
    ///
    /// # Panics
    /// Panics via `unimplemented!` for any data type other than `UInt64`,
    /// `Int64`, `Float64`, `Decimal128` and `Decimal256`.
    fn create_accumulator(&self) -> Box<dyn Accumulator> {
        match &self.return_type {
            DataType::UInt64 => Box::new(SumAccumulator::<UInt64Type>::new()),
            DataType::Int64 => Box::new(SumAccumulator::<Int64Type>::new()),
            DataType::Float64 => Box::new(SumAccumulator::<Float64Type>::new()),
            DataType::Decimal128(_, _) => Box::new(SumAccumulator::<Decimal128Type>::new()),
            DataType::Decimal256(_, _) => Box::new(SumAccumulator::<Decimal256Type>::new()),
            unsupported => unimplemented!("Sum not supported for {}: {}", self.expr, unsupported),
        }
    }
}

/// Running-sum state for one numeric Arrow type `T`.
struct SumAccumulator<T: ArrowNumericType> {
/// Total so far; stays `None` until a batch contributes a value.
sum: Option<T::Native>,
}

impl<T: ArrowNumericType> SumAccumulator<T> {
pub fn new() -> Self {
Self { sum: None }
}
}

impl<T: ArrowNumericType> Debug for SumAccumulator<T> {
    /// Writes the fixed label `SumAccumulator`; the running total is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SumAccumulator")
    }
}

impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
fn accumluate(&mut self, value: &ArrayRef) -> Result<()> {
let values = value.as_primitive::<T>();

if let Some(x) = compute::sum(values) {
let v = self.sum.get_or_insert(T::Native::default());
*v = v.add_wrapping(x);
}

Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let val = self
.sum
.ok_or(Error::InternalError("No value to evaluate".to_string()))?;

match T::DATA_TYPE {
DataType::UInt64 => Ok(ScalarValue::UInt64(val.to_usize().map(|x| x as u64))),
DataType::Int64 => Ok(ScalarValue::Int64(val.to_usize().map(|x| x as i64))),
// DataType::Float64 => Ok(ScalarValue::Float64(val)),
// DataType::Decimal128(_, _) => Box::new(SumAccumulator::<Decimal128Type>::new()),
// DataType::Decimal256(_, _) => Box::new(SumAccumulator::<Decimal256Type>::new()),
_ => Err(Error::InternalError(format!("Sum not supported for {}", T::DATA_TYPE))),
}
}
}
Loading

0 comments on commit a12dd81

Please sign in to comment.