Skip to content

Commit

Permalink
track aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jun 15, 2024
1 parent 8397ad5 commit 3ba5460
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ impl CommonSubexprEliminate {
} = aggregate;
let mut expr_stats = ExprStats::new();

// track transformed information
let mut transformed = false;

// rewrite inputs
let input_schema = Arc::clone(input.schema());
let group_arrays = to_arrays(
Expand All @@ -371,15 +374,16 @@ impl CommonSubexprEliminate {
.collect::<Result<Vec<_>>>()?;

// rewrite both group exprs and aggr_expr
let (mut new_expr, new_input) = self
.rewrite_expr(
vec![group_expr, aggr_expr],
&[&group_arrays, &aggr_arrays],
unwrap_arc(input),
&expr_stats,
config,
)?
.data;
let rewritten = self.rewrite_expr(
vec![group_expr, aggr_expr],
&[&group_arrays, &aggr_arrays],
unwrap_arc(input),
&expr_stats,
config,
)?;
transformed |= rewritten.transformed;
let (mut new_expr, new_input) = rewritten.data;

// note the reversed pop order.
let new_aggr_expr = pop_expr(&mut new_expr)?;
let new_group_expr = pop_expr(&mut new_expr)?;
Expand All @@ -394,15 +398,14 @@ impl CommonSubexprEliminate {
ExprMask::NormalAndAggregates,
)?;
let mut common_exprs = IndexMap::new();
let mut rewritten = self
.rewrite_exprs_list(
vec![new_aggr_expr.clone()],
&[&aggr_arrays],
&expr_stats,
&mut common_exprs,
)?
.data;
let rewritten = pop_expr(&mut rewritten)?;
let mut rewritten_exprs = self.rewrite_exprs_list(
vec![new_aggr_expr.clone()],
&[&aggr_arrays],
&expr_stats,
&mut common_exprs,
)?;
transformed |= rewritten_exprs.transformed;
let rewritten = pop_expr(&mut rewritten_exprs.data)?;

if common_exprs.is_empty() {
// Alias aggregation expressions if they have changed
Expand All @@ -411,14 +414,13 @@ impl CommonSubexprEliminate {
.zip(saved_names.into_iter())
.map(|(new_expr, saved_name)| saved_name.restore(new_expr))
.collect::<Result<Vec<Expr>>>()?;
// Since group_expr changes, schema changes also. Use try_new method.
return Aggregate::try_new(
// Since group_expr changes, schema may also. Use try_new method.
let new_agg = LogicalPlan::Aggregate(Aggregate::try_new(
Arc::new(new_input),
new_group_expr,
new_aggr_expr,
)
.map(LogicalPlan::Aggregate)
.map(Transformed::yes);
)?);
return Ok(Transformed::new_transformed(new_agg, transformed));
}
let mut agg_exprs = common_exprs
.into_iter()
Expand Down Expand Up @@ -1278,7 +1280,9 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test";
let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\
\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, plan);

Expand Down

0 comments on commit 3ba5460

Please sign in to comment.