martin-g commented on code in PR #21021:
URL: https://github.com/apache/datafusion/pull/21021#discussion_r2974933012
##########
datafusion/core/src/dataframe/mod.rs:
##########
@@ -410,21 +412,102 @@ impl DataFrame {
expr_list: impl IntoIterator<Item = impl Into<SelectExpr>>,
) -> Result<DataFrame> {
let expr_list: Vec<SelectExpr> =
- expr_list.into_iter().map(|e| e.into()).collect::<Vec<_>>();
+ expr_list.into_iter().map(|e| e.into()).collect();
+ // Extract expressions
let expressions = expr_list.iter().filter_map(|e| match e {
SelectExpr::Expression(expr) => Some(expr),
_ => None,
});
- let window_func_exprs = find_window_exprs(expressions);
- let plan = if window_func_exprs.is_empty() {
+ // Apply window functions first
Review Comment:
Could you please add some test cases with window functions too?
##########
datafusion/core/src/dataframe/mod.rs:
##########
Review Comment:
Maybe also update the example to use aggregate and window functions?
##########
datafusion/core/tests/dataframe/mod.rs:
##########
@@ -6854,3 +6857,74 @@ async fn
test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> {
+ let df = test_table().await?;
+
+ // Multiple aggregates
+ let res = df.clone().select(vec![
+ count(col("c9")).alias("count_c9"),
+ count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"),
+ sum(col("c9")).alias("sum_c9"),
+ count(col("c8")).alias("count_c8"),
+ (sum(col("c9")) + count(col("c8"))).alias("total1"),
+ ((count(col("c9")) + lit(1)) * lit(2)).alias("total2"),
+ (count(col("c9")) + lit(1)).alias("count_c9_add_1"),
+ ])?;
+
+ assert_batches_eq!(
+ &[
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ "| count_c9 | count_c9_str | sum_c9 | count_c8 | total1
| total2 | count_c9_add_1 |",
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ "| 100 | 100 | 222089770060 | 100 |
222089770160 | 202 | 101 |",
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ ],
+ &res.collect().await?
+ );
+
+ // Test duplicate aggregate aliases
+ let res = df.clone().select(vec![
+ count(col("c9")).alias("count_c9"),
+ count(col("c9")).alias("count_c9_2"),
+ ])?;
+
+ assert_batches_eq!(
+ &[
+ "+----------+------------+",
+ "| count_c9 | count_c9_2 |",
+ "+----------+------------+",
+ "| 100 | 100 |",
+ "+----------+------------+",
+ ],
+ &res.collect().await?
+ );
+
+ // Wildcard
+ let res = df
+ .clone()
+ .select(vec![
+ SelectExpr::Wildcard(WildcardOptions::default()),
+ lit(42).into(),
+ ])?
+ .limit(0, None)?;
+
+ let batches = res.collect().await?;
+ assert_eq!(batches[0].num_rows(), 100);
+ assert_eq!(batches[0].num_columns(), 14);
Review Comment:
```suggestion
assert!(!batches.is_empty());
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 100);
assert!(batches.iter().all(|b| b.num_columns() == 14));
```
##########
datafusion/core/tests/dataframe/mod.rs:
##########
@@ -6854,3 +6857,74 @@ async fn
test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> {
+ let df = test_table().await?;
+
+ // Multiple aggregates
+ let res = df.clone().select(vec![
+ count(col("c9")).alias("count_c9"),
+ count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"),
+ sum(col("c9")).alias("sum_c9"),
+ count(col("c8")).alias("count_c8"),
+ (sum(col("c9")) + count(col("c8"))).alias("total1"),
+ ((count(col("c9")) + lit(1)) * lit(2)).alias("total2"),
+ (count(col("c9")) + lit(1)).alias("count_c9_add_1"),
+ ])?;
+
+ assert_batches_eq!(
+ &[
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ "| count_c9 | count_c9_str | sum_c9 | count_c8 | total1
| total2 | count_c9_add_1 |",
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ "| 100 | 100 | 222089770060 | 100 |
222089770160 | 202 | 101 |",
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ ],
+ &res.collect().await?
+ );
+
+ // Test duplicate aggregate aliases
+ let res = df.clone().select(vec![
+ count(col("c9")).alias("count_c9"),
+ count(col("c9")).alias("count_c9_2"),
+ ])?;
+
+ assert_batches_eq!(
+ &[
+ "+----------+------------+",
+ "| count_c9 | count_c9_2 |",
+ "+----------+------------+",
+ "| 100 | 100 |",
+ "+----------+------------+",
+ ],
+ &res.collect().await?
+ );
+
+ // Wildcard
+ let res = df
+ .clone()
+ .select(vec![
+ SelectExpr::Wildcard(WildcardOptions::default()),
+ lit(42).into(),
+ ])?
+ .limit(0, None)?;
+
+ let batches = res.collect().await?;
+ assert_eq!(batches[0].num_rows(), 100);
+ assert_eq!(batches[0].num_columns(), 14);
+
+ let res = df.clone().select(vec![
+ SelectExpr::QualifiedWildcard(
+ "aggregate_test_100".into(),
+ WildcardOptions::default(),
+ ),
+ lit(42).into(),
+ ])?;
+
+ let batches = res.collect().await?;
+ assert_eq!(batches[0].num_rows(), 100);
+ assert_eq!(batches[0].num_columns(), 14);
Review Comment:
```suggestion
assert!(!batches.is_empty());
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 100);
assert!(batches.iter().all(|b| b.num_columns() == 14));
```
##########
datafusion/core/tests/dataframe/mod.rs:
##########
@@ -6854,3 +6857,74 @@ async fn
test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> {
+ let df = test_table().await?;
+
+ // Multiple aggregates
+ let res = df.clone().select(vec![
+ count(col("c9")).alias("count_c9"),
+ count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"),
+ sum(col("c9")).alias("sum_c9"),
+ count(col("c8")).alias("count_c8"),
+ (sum(col("c9")) + count(col("c8"))).alias("total1"),
+ ((count(col("c9")) + lit(1)) * lit(2)).alias("total2"),
+ (count(col("c9")) + lit(1)).alias("count_c9_add_1"),
+ ])?;
+
+ assert_batches_eq!(
+ &[
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ "| count_c9 | count_c9_str | sum_c9 | count_c8 | total1
| total2 | count_c9_add_1 |",
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ "| 100 | 100 | 222089770060 | 100 |
222089770160 | 202 | 101 |",
+
"+----------+--------------+--------------+----------+--------------+--------+----------------+",
+ ],
+ &res.collect().await?
+ );
+
+ // Test duplicate aggregate aliases
+ let res = df.clone().select(vec![
+ count(col("c9")).alias("count_c9"),
+ count(col("c9")).alias("count_c9_2"),
Review Comment:
```suggestion
count(col("c9")),
count(col("c9")),
```
Let's remove the "manual" aliases here and assert that the logic at
https://github.com/apache/datafusion/pull/21021/changes#diff-997707d7dfcac94032b84a25bc0010c62209bf767e3abc6580a55a0a97c19de2R498
generates unique aliases.
##########
datafusion/core/src/dataframe/mod.rs:
##########
@@ -410,21 +412,102 @@ impl DataFrame {
expr_list: impl IntoIterator<Item = impl Into<SelectExpr>>,
) -> Result<DataFrame> {
let expr_list: Vec<SelectExpr> =
- expr_list.into_iter().map(|e| e.into()).collect::<Vec<_>>();
+ expr_list.into_iter().map(|e| e.into()).collect();
+ // Extract expressions
let expressions = expr_list.iter().filter_map(|e| match e {
SelectExpr::Expression(expr) => Some(expr),
_ => None,
});
- let window_func_exprs = find_window_exprs(expressions);
- let plan = if window_func_exprs.is_empty() {
+ // Apply window functions first
+ let window_func_exprs = find_window_exprs(expressions.clone());
+
+ let mut plan = if window_func_exprs.is_empty() {
self.plan
} else {
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
};
- let project_plan =
LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
+ // Collect aggregate expressions
+ let aggr_exprs = find_aggregate_exprs(expressions.clone());
+
+ // Check for non-aggregate expressions
+ let has_non_aggregate_expr = expressions
+ .clone()
+ .any(|expr|
find_aggregate_exprs(std::iter::once(expr)).is_empty());
+
+ // Fallback to projection:
+ // - already aggregated
+ // - contains non-aggregate expressions
+ // - no aggregates
+ if matches!(plan, LogicalPlan::Aggregate(_))
+ || has_non_aggregate_expr
+ || aggr_exprs.is_empty()
+ {
+ let project_plan =
+ LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
+
+ return Ok(DataFrame {
+ session_state: self.session_state,
+ plan: project_plan,
+ projection_requires_validation: false,
+ });
+ }
+
+ // Assign aliases to aggregate expressions
+ let mut aggr_map: HashMap<Expr, Expr> = HashMap::new();
+ let mut used_names = HashSet::new();
+ let aggr_exprs_with_alias: Vec<Expr> = aggr_exprs
+ .into_iter()
+ .map(|expr| {
+ let base_name = expr.name_for_alias()?;
+ let mut name = base_name.clone();
+ let mut counter = 1;
+ while used_names.contains(&name) {
+ name = format!("{base_name}_{counter}");
+ counter += 1;
+ }
+ used_names.insert(name.clone());
+ let aliased = expr.clone().alias(name.clone());
+ let col = Expr::Column(Column::from_name(name));
+ aggr_map.insert(expr, col);
+ Ok(aliased)
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ // Build aggregate plan
+ plan = LogicalPlanBuilder::from(plan)
+ .aggregate(Vec::<Expr>::new(), aggr_exprs_with_alias)?
+ .build()?;
+
+ // Rewrite expressions to use aggregate outputs
+ let rewrite_expr = |expr: Expr, aggr_map: &HashMap<Expr, Expr>| ->
Result<Expr> {
+ expr.transform(|e| {
+ Ok(match aggr_map.get(&e) {
+ Some(replacement) => Transformed::yes(replacement.clone()),
+ None => Transformed::no(e),
+ })
+ })
+ .map(|t| t.data)
+ };
+
+ let mut rewritten_exprs = Vec::with_capacity(expr_list.len());
+ for select_expr in expr_list.into_iter() {
+ match select_expr {
+ SelectExpr::Expression(expr) => {
+ let rewritten = rewrite_expr(expr.clone(), &aggr_map)?;
+ let alias = expr.name_for_alias()?;
+
rewritten_exprs.push(SelectExpr::Expression(rewritten.alias(alias)));
Review Comment:
```suggestion
let alias = expr.name_for_alias()?;
let rewritten = rewrite_expr(expr, &aggr_map)?;
let final_expr = match &rewritten {
Expr::Alias(_) => rewritten,
_ => rewritten.alias(alias),
};
rewritten_exprs.push(SelectExpr::Expression(final_expr));
```
Only add an alias if the rewritten expression doesn't already have one.
##########
datafusion/core/src/dataframe/mod.rs:
##########
@@ -410,21 +412,102 @@ impl DataFrame {
expr_list: impl IntoIterator<Item = impl Into<SelectExpr>>,
) -> Result<DataFrame> {
let expr_list: Vec<SelectExpr> =
- expr_list.into_iter().map(|e| e.into()).collect::<Vec<_>>();
+ expr_list.into_iter().map(|e| e.into()).collect();
+ // Extract expressions
let expressions = expr_list.iter().filter_map(|e| match e {
SelectExpr::Expression(expr) => Some(expr),
_ => None,
});
- let window_func_exprs = find_window_exprs(expressions);
- let plan = if window_func_exprs.is_empty() {
+ // Apply window functions first
+ let window_func_exprs = find_window_exprs(expressions.clone());
+
+ let mut plan = if window_func_exprs.is_empty() {
self.plan
} else {
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
};
- let project_plan =
LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
+ // Collect aggregate expressions
+ let aggr_exprs = find_aggregate_exprs(expressions.clone());
+
+ // Check for non-aggregate expressions
+ let has_non_aggregate_expr = expressions
Review Comment:
`expressions` is filtered above to contain only `SelectExpr::Expression` items.
What if the original was `SELECT *, count(col("a")) ...`?
The wildcard would have been dropped by that filter, so `has_non_aggregate_expr`
would be false here even though the select list contains a non-aggregate item.
Please add an aggregate function to the test at
https://github.com/apache/datafusion/pull/21021/changes#diff-4a599584dfc900ec21169f4f820a1b1db46b004b77533dab83a6178d5d3a467eR6909
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]