Skip to content

Commit

Permalink
feat: group by
Browse files Browse the repository at this point in the history
  • Loading branch information
holicc committed May 4, 2024
1 parent 84cd82d commit 51491ac
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 107 deletions.
1 change: 0 additions & 1 deletion src/datasource/file/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use url::Url;

use crate::error::{Error, Result};


pub trait DataFilePath {
fn to_url(self) -> Result<Url>;
}
Expand Down
1 change: 0 additions & 1 deletion src/datatypes/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use arrow::{
datatypes::{DataType, Field},
};

use crate::error::Result;
use std::{fmt::Display, sync::Arc};

#[derive(Debug, Clone, PartialEq, PartialOrd)]
Expand Down
3 changes: 2 additions & 1 deletion src/logical/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::sync::Arc;

pub use aggregate::{AggregateExpr, AggregateOperator};

use arrow::datatypes::{Field, FieldRef};
use arrow::datatypes::{DataType, Field, FieldRef};
pub use binary::*;
pub use column::*;
pub use literal::*;
Expand Down Expand Up @@ -41,6 +41,7 @@ impl LogicalExpr {
LogicalExpr::AggregateExpr(a) => a.field(plan),
LogicalExpr::Literal(v) => Ok(Arc::new(v.to_field())),
LogicalExpr::Alias(a) => a.expr.field(plan),
LogicalExpr::Wildcard => Ok(Arc::new(Field::new("*", DataType::Null, true))),
_ => Err(Error::InternalError(format!(
"Cannot determine schema for expression: {:?}",
self
Expand Down
9 changes: 3 additions & 6 deletions src/logical/plan/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ pub struct Aggregate {

impl Aggregate {
pub fn try_new(input: LogicalPlan, group_expr: Vec<LogicalExpr>, aggr_expr: Vec<AggregateExpr>) -> Result<Self> {
let schema = group_expr
.iter()
.map(|f| f.field(&input))
.collect::<Result<Vec<_>>>()
.map(|fields| Arc::new(Schema::new(fields)))?;
let group_fields = group_expr.iter().map(|f| f.field(&input));
let agg_fields = aggr_expr.iter().map(|f| &f.expr).map(|f| f.field(&input));

Ok(Self {
schema,
schema: Arc::new(Schema::new(group_fields.chain(agg_fields).collect::<Result<Vec<_>>>()?)),
input: Box::new(input),
group_expr,
aggr_expr,
Expand Down
2 changes: 0 additions & 2 deletions src/logical/plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ pub use sub_query::SubqueryAlias;

use arrow::datatypes::SchemaRef;

use super::expr::Column;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmptyRelation {
schema: SchemaRef,
Expand Down
2 changes: 1 addition & 1 deletion src/physical/expr/aggregate/max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{fmt::Display, sync::Arc};

use arrow::array::{ArrayRef, AsArray};
use arrow::compute;
use arrow::datatypes::{DataType, Int32Type};
use arrow::datatypes::Int32Type;

use super::{Accumulator, AggregateExpr};
use crate::datatypes::scalar::ScalarValue;
Expand Down
2 changes: 1 addition & 1 deletion src/physical/expr/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Display for BinaryExpr {

#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Schema};
use arrow::datatypes::DataType;

use super::*;
use crate::{build_schema, physical::expr::Column, test_utils::build_record_i32};
Expand Down
195 changes: 107 additions & 88 deletions src/physical/plan/aggregate/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use crate::{
use std::{
collections::HashMap,
fmt::Display,
hash::{DefaultHasher, Hash, Hasher, RandomState},
hash::{DefaultHasher, Hash, Hasher},
sync::Arc,
};

use arrow::{
array::{make_array, make_builder, ArrayRef, AsArray, Int32Builder, Int64Array, PrimitiveArray, UInt64Array},
array::{ArrayRef, AsArray, UInt64Array},
compute,
datatypes::{ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int8Type, SchemaRef},
datatypes::{ArrowPrimitiveType, DataType, FieldRef, Int16Type, Int8Type, SchemaRef},
record_batch::RecordBatch,
};

Expand Down Expand Up @@ -52,6 +52,19 @@ impl HashAggregate {
_ => unimplemented!(),
}
}

fn is_group_field(&self, field: &FieldRef) -> bool {
let metadata = field.metadata();
metadata.contains_key("IS_GROUP_FIELD")
}

fn take_group_values_by_field(&self, _group_values: &Vec<ArrayRef>, _field: &FieldRef) -> ArrayRef {
todo!()
}

fn get_agg_expr_by_field(&self, _field: &FieldRef) -> Option<&Arc<dyn AggregateExpr>> {
todo!()
}
}

impl PhysicalPlan for HashAggregate {
Expand Down Expand Up @@ -81,44 +94,37 @@ impl PhysicalPlan for HashAggregate {
// for each batch from the input executor
for batch in self.input.execute()? {
// evaluate the groupt expression
let gourp_by_values = self
let group_by_values = self
.group_exprs
.iter()
.map(|e| e.evaluate(&batch))
.collect::<Result<Vec<ArrayRef>>>()?;
let group_indices = group_indices(&gourp_by_values)?;
// evalute the expressions that are inputs to the aggregate functions
let agg_input_values = self
.aggregate_exprs
.iter()
.map(|f| f.expression().evaluate(&batch))
.collect::<Result<Vec<ArrayRef>>>()?;
// for each row in the batch perform accumulation

let mut columns = vec![];
for (i, agg_expr) in self.aggregate_exprs.iter().enumerate() {
let mut array = vec![];
let mut acc = agg_expr.create_accumulator();

for indices in &group_indices {
compute::take(&agg_input_values[i], indices, None)
.map_err(|err| Error::ArrowError(err))
.and_then(|input_values| acc.accumluate(&input_values))?;

let single_array = acc.evaluate().map(|v| v.to_array(1))?;

if let Some(latest) = array.pop() {
array.push(compute::concat(&[&single_array, &latest])?);
} else {
array.push(single_array);
}
let (group_array, group_indices) = group_indices(&group_by_values)?;

columns.extend(group_array);

for agg_expr in &self.aggregate_exprs {
let mut agg_array = vec![];
for group in &group_indices {
let mut acc = agg_expr.create_accumulator();
// every aggregate expression only evaluate one value as output
let input_array = agg_expr
.expression()
.evaluate(&batch)
.and_then(|values| compute::take(&values, &group, None).map_err(|e| Error::ArrowError(e)))?;
acc.accumluate(&input_array)?;
agg_array.push(acc.evaluate().map(|v| v.to_array(1))?);
}
columns.extend(array);
let t = agg_array.iter().map(|f| f.as_ref()).collect::<Vec<_>>();

columns.push(compute::concat(&t)?);
}

results.push(RecordBatch::try_new(self.schema(), columns)?);
}

// create result batch containing final aggregate values
Ok(results)
}

Expand All @@ -137,51 +143,57 @@ impl Display for HashAggregate {
}
}

fn group_indices(group_values: &Vec<ArrayRef>) -> Result<Vec<UInt64Array>> {
fn group_indices(values: &Vec<ArrayRef>) -> Result<(Vec<ArrayRef>, Vec<UInt64Array>)> {
use arrow::datatypes::*;

let mut hasher_map = HashMap::new();
for values in group_values {
match values.data_type() {
DataType::UInt8 => hash_primitive_array::<UInt8Type>(values, &mut hasher_map),
DataType::UInt16 => hash_primitive_array::<UInt16Type>(values, &mut hasher_map),
DataType::UInt32 => hash_primitive_array::<UInt32Type>(values, &mut hasher_map),
DataType::UInt64 => hash_primitive_array::<UInt64Type>(values, &mut hasher_map),
DataType::Int8 => hash_primitive_array::<Int8Type>(values, &mut hasher_map),
DataType::Int16 => hash_primitive_array::<Int16Type>(values, &mut hasher_map),
DataType::Int32 => hash_primitive_array::<Int32Type>(values, &mut hasher_map),
DataType::Int64 => hash_primitive_array::<Int64Type>(values, &mut hasher_map),
for group in values.iter() {
match group.data_type() {
DataType::UInt8 => hash_primitive_array::<UInt8Type>(group, &mut hasher_map),
DataType::Int32 => hash_primitive_array::<Int32Type>(group, &mut hasher_map),
_ => {
return Err(Error::InternalError(format!(
"Unsupported data type {:?}",
values.data_type()
group.data_type()
)))
}
};
}
}
let mut indices_map = HashMap::new();

let mut map: HashMap<u64, Vec<u64>> = HashMap::new();
for (index, hasher) in hasher_map {
let hash = hasher.finish();
let indices = indices_map.entry(hash).or_insert(vec![]);
indices.push(index as u64);
if let Some(val) = map.get_mut(&hash) {
val.push(index as u64);
} else {
map.insert(hash, vec![index as u64]);
}
}

let mut group_indices = vec![];
let mut agg_indices = vec![];
for group in map.into_values() {
group_indices.push(group[0]);
agg_indices.push(UInt64Array::from_iter(group));
}
Ok(indices_map
.into_iter()
.map(|(_, v)| UInt64Array::from_iter(v))
.collect())

let mut group_values = vec![];
let group_indices = UInt64Array::from_iter(group_indices);
for v in values {
group_values.push(compute::take(v, &group_indices, None)?);
}

Ok((group_values, agg_indices))
}

fn hash_primitive_array<T: ArrowPrimitiveType>(values: &ArrayRef, hasher_map: &mut HashMap<usize, DefaultHasher>)
fn hash_primitive_array<T: ArrowPrimitiveType>(group_values: &ArrayRef, hasher_map: &mut HashMap<usize, DefaultHasher>)
where
T::Native: Hash,
{
for (index, v) in values.as_primitive::<T>().iter().enumerate() {
let mut hasher = hasher_map.entry(index).or_insert(DefaultHasher::new());
match v {
Some(key) => {
key.hash(&mut hasher);
}
None => {}
for (i, v) in group_values.as_primitive::<T>().iter().enumerate() {
if let Some(val) = v {
let mut hasher = hasher_map.entry(i).or_insert(DefaultHasher::new());
val.hash(&mut hasher);
}
}
}
Expand All @@ -190,7 +202,10 @@ where
mod tests {
use std::sync::Arc;

use arrow::{array::Int32Array, datatypes::DataType};
use arrow::{
array::{Array, Int32Array},
datatypes::DataType,
};

use crate::{
build_schema,
Expand All @@ -199,55 +214,59 @@ mod tests {
expr::MaxAggregateExpr,
plan::{tests::build_table_scan_i32, PhysicalPlan},
},
test_utils::assert_batch_eq,
};

use super::{group_indices, HashAggregate};

#[test]
fn test_group_by() {
let schema = build_schema!(("MAX(a1)", DataType::Int32), ("c1", DataType::Int32),);
let schema = build_schema!(
("c1", DataType::Int32),
("b1", DataType::Int32),
("MAX(a1)", DataType::Int32)
);

let input = build_table_scan_i32(vec![
("a1", vec![1, 2, 3]),
("b1", vec![4, 5, 6]),
("c1", vec![7, 8, 9]),
("a1", vec![1, 2, 13, 6]),
("b1", vec![4, 5, 6, 6]),
("c1", vec![7, 8, 9, 9]),
]);

// group by b1
let group_exprs = vec![Arc::new(physical::expr::Column::new("b1", 1)) as Arc<_>];
// max(a1)
let aggregate_exprs = vec![
Arc::new(MaxAggregateExpr {
expr: Arc::new(physical::expr::Column::new("a1", 0)),
}) as Arc<_>,
Arc::new(MaxAggregateExpr {
expr: Arc::new(physical::expr::Column::new("a1", 0)),
}) as Arc<_>,
// group by b1,c1
let group_exprs = vec![
Arc::new(physical::expr::Column::new("c1", 2)) as Arc<_>,
Arc::new(physical::expr::Column::new("b1", 1)) as Arc<_>,
];
// max(a1)
let aggregate_exprs = vec![Arc::new(MaxAggregateExpr {
expr: Arc::new(physical::expr::Column::new("a1", 0)),
}) as Arc<_>];

let agg = HashAggregate::new(Arc::new(schema), input, group_exprs, aggregate_exprs);

let results = agg.execute().unwrap();

assert_batch_eq(
&results,
vec![
"+------------+",
"| MAX(a1) |",
"+------------+",
"| 2 |",
"| 1 |",
"| 3 |",
"+------------+",
],
)
assert_eq!(results.len(), 1);
// assert_batch_eq(
// &results,
// vec![
// "+------------+",
// "| MAX(a1) |",
// "+------------+",
// "| 2 |",
// "| 1 |",
// "| 3 |",
// "+------------+",
// ],
// )
}

#[test]
fn test_group_indices() {
let group_field = Arc::new(Int32Array::from_iter(vec![7, 8, 9, 9]));
let results = group_indices(&vec![group_field]).unwrap();
let group_field0: Arc<dyn Array> = Arc::new(Int32Array::from_iter(vec![7, 9, 8, 9]));
let group_field1: Arc<dyn Array> = Arc::new(Int32Array::from_iter(vec![1, 3, 2, 3]));

let (_, results) = group_indices(&vec![group_field0, group_field1]).unwrap();

assert_eq!(results.len(), 3);
}
Expand Down
2 changes: 1 addition & 1 deletion src/physical/plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod scan;

pub use aggregate::HashAggregate;
pub use filter::Filter;
pub use join::{CrossJoin, Join, JoinSide,JoinFilter,ColumnIndex,join_schema};
pub use join::{join_schema, ColumnIndex, CrossJoin, Join, JoinFilter, JoinSide};
pub use projection::Projection;
pub use scan::Scan;

Expand Down
4 changes: 2 additions & 2 deletions src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl DefaultQueryPlanner {

fn physical_plan_aggregate(&self, aggregate: &Aggregate) -> Result<Arc<dyn PhysicalPlan>> {
let input = self.create_physical_plan(&aggregate.input)?;

let group_expr = aggregate
.group_expr
.iter()
Expand Down Expand Up @@ -185,7 +185,7 @@ pub(crate) fn normalize_col_with_schemas_and_ambiguity_check(
) -> Result<LogicalExpr> {
match expr {
LogicalExpr::Column(mut col) => {
if col.relation.is_some(){
if col.relation.is_some() {
return Ok(LogicalExpr::Column(col));
}

Expand Down
Loading

0 comments on commit 51491ac

Please sign in to comment.